ragflow/rag/utils/tts_cache.py
plind f169ab4b39 feat(tts): cache synthesized speech in Redis to avoid redundant calls (#14851)
## What problem does this PR solve?

Closes #12017.

TTS output is deterministic for a given `(model, text)` pair: re-running
the same text through the same TTS model produces the same bytes. Yet
`Canvas.tts` and `dialog_service.tts` re-synthesized the audio on every
request. That is slow and wastes provider quota whenever the same
assistant response is replayed, shared across users, or repeated within
a session.

### Change

New helper `rag/utils/tts_cache.py` with `synthesize_with_cache(tts_mdl,
cleaned_text)`:

- **Key:** `tts:cache:{model_id}:{sha256(text)}`. Each model gets its
own namespace, and identical cleaned text reuses a single entry across
both call sites.
- **Value:** the hex-encoded audio blob both call sites already
returned. No format change for downstream consumers.
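- **TTL:** 7 days by default, configurable via
`RAGFLOW_TTS_CACHE_TTL_SECONDS`. A value of `0` or less disables storing
new entries; a non-integer value logs a warning and falls back to the
default.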
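- **Failure modes:** a Redis error on lookup falls back to direct
synthesis, and a Redis error on store is logged while the synthesized
audio is still returned. A failed or empty synthesis still returns
`None` (existing contract preserved).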
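The key and TTL rules above can be exercised in isolation. This is a minimal sketch that mirrors the logic of `_build_key` and `_ttl_seconds` in `tts_cache.py`. `build_cache_key` and `parse_ttl` are hypothetical names, and the raw environment value is passed in as an argument instead of being read from `RAGFLOW_TTS_CACHE_TTL_SECONDS`, so it runs without RAGFlow installed.

```python
import hashlib
import logging
from typing import Optional

KEY_PREFIX = "tts:cache:"  # same prefix as _KEY_PREFIX in the helper
DEFAULT_TTL_SECONDS = 7 * 24 * 60 * 60  # 7 days, as in _DEFAULT_TTL_SECONDS


def build_cache_key(model_id: str, text: str) -> str:
    """Hypothetical helper: tts:cache:{model_id}:{sha256(text)}."""
    digest = hashlib.sha256(text.encode("utf-8", "ignore")).hexdigest()
    return f"{KEY_PREFIX}{model_id}:{digest}"


def parse_ttl(raw: Optional[str]) -> int:
    """Hypothetical helper interpreting a raw RAGFLOW_TTS_CACHE_TTL_SECONDS value.

    Unset or empty -> default; non-integer -> default plus a warning;
    zero or negative -> 0, which the helper treats as "do not store".
    """
    if not raw:
        return DEFAULT_TTL_SECONDS
    try:
        value = int(raw)
    except ValueError:
        logging.warning("Invalid TTL %r, using default", raw)
        return DEFAULT_TTL_SECONDS
    return value if value > 0 else 0


same_text = build_cache_key("model-a", "Hello!") == build_cache_key("model-a", "Hello!")
cross_model = build_cache_key("model-a", "Hello!") == build_cache_key("model-b", "Hello!")
print(same_text, cross_model)  # True False
print(parse_ttl("120"), parse_ttl("-5"))  # 120 0
```

Hashing the text keeps key length fixed at 64 hex characters regardless of how long the assistant response is, while the model identifier stays readable in the key.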


[`Canvas.tts`](https://github.com/infiniflow/ragflow/blob/main/agent/canvas.py#L683-L724)
and
[`dialog_service.tts`](https://github.com/infiniflow/ragflow/blob/main/api/db/services/dialog_service.py#L1367-L1380)
now route through the helper; the per-file bytes-accumulation/hex-encode
loop has been removed in favor of one shared implementation.
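The shared step that replaced the per-file loops is small: drain the streamed chunks, concatenate the bytes, and hex-encode the result. A minimal sketch, assuming the model's `tts()` yields `bytes` chunks; `fake_tts_stream` and `collect_hex` are hypothetical stand-ins, not RAGFlow APIs. The last line shows the inverse a consumer would apply to get raw audio back.

```python
import binascii
from typing import Iterable, Iterator, Optional


def fake_tts_stream(text: str) -> Iterator[bytes]:
    """Hypothetical stand-in for tts_mdl.tts(text): yields audio in 4-byte chunks."""
    payload = f"AUDIO<{text}>".encode("utf-8")
    for i in range(0, len(payload), 4):
        yield payload[i : i + 4]


def collect_hex(chunks: Iterable[object]) -> Optional[str]:
    """Concatenate byte chunks and return them hex-encoded, or None if empty.

    Non-bytes chunks are skipped, matching the helper's isinstance check.
    """
    buf = bytearray()
    for chunk in chunks:
        if isinstance(chunk, (bytes, bytearray)):
            buf.extend(chunk)
    if not buf:
        return None
    return binascii.hexlify(bytes(buf)).decode("ascii")


hex_audio = collect_hex(fake_tts_stream("hi"))
raw_audio = bytes.fromhex(hex_audio)  # what a consumer would play back
```

The sketch extends a `bytearray` rather than concatenating `bytes` objects; both give the same result, but extending in place avoids copying the whole buffer on every chunk when responses are long.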

## Type of change

- [x] New Feature (non-breaking change which adds functionality)

## Test plan

- [ ] **Cache hit, chat path:** Configure a dialog with TTS enabled and ask
the same question twice with `stream=false`. Verify the second response
returns the same `audio_binary` and that the second invocation doesn't
hit the TTS provider (e.g., observe provider-side logs / usage counters,
and check that no `LLMBundle.tts can't update token usage` log line
appears on the second run).
- [ ] **Cache hit, agent path:** Same exercise via a Conversational
Agent that includes a Message component playing back the answer.
- [ ] **Cache isolation per model:** Switch the tenant's `tts_id` between
two models and run the same text against each. Confirm the second model's
first synthesis still happens (no cross-model hits).
- [ ] **TTL override:** Set `RAGFLOW_TTS_CACHE_TTL_SECONDS=120`, confirm
the entry expires after 2 minutes.
- [ ] **Redis unavailable:** Stop Redis (or break the connection).
Verify the TTS endpoint still works: synthesis falls back to direct
calls, and a `TTS cache lookup failed` / `TTS cache store failed`
warning is logged.
- [ ] **Failure path:** Configure a TTS model with an invalid API key
and ensure the response still returns successfully with
`audio_binary=None` (no regression vs. current behavior).
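Several of these checks can also be rehearsed offline. The sketch below reimplements the helper's control flow against an injectable cache so the cache-hit, per-model isolation, TTL-disabled, and Redis-down cases can be asserted without a real Redis or TTS provider. `DictCache`, `BrokenCache`, `CountingTTS`, and `cached_synthesize` are hypothetical names; the real helper talks to `REDIS_CONN` directly and reads its TTL from the environment.

```python
import binascii
import hashlib
import logging
from typing import Dict, Iterator, Optional


class DictCache:
    """In-memory stand-in for Redis get/set (expiry is ignored)."""

    def __init__(self) -> None:
        self.store: Dict[str, str] = {}

    def get(self, key: str) -> Optional[str]:
        return self.store.get(key)

    def set(self, key: str, value: str, exp: int) -> None:
        self.store[key] = value


class BrokenCache:
    """Simulates Redis being unavailable: every call raises."""

    def get(self, key: str) -> Optional[str]:
        raise ConnectionError("redis down")

    def set(self, key: str, value: str, exp: int) -> None:
        raise ConnectionError("redis down")


class CountingTTS:
    """Fake TTS model; `calls` increments each time a stream is consumed."""

    def __init__(self, model_id: str) -> None:
        self.model_config = {"id": model_id}
        self.calls = 0

    def tts(self, text: str) -> Iterator[bytes]:
        self.calls += 1
        yield f"{self.model_config['id']}:{text}".encode("utf-8")


def cached_synthesize(cache, tts_mdl, text: str, ttl: int = 60) -> Optional[str]:
    """Same flow as synthesize_with_cache: lookup, synthesize on miss, store."""
    if not tts_mdl or not text:
        return None
    digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
    key = f"tts:cache:{tts_mdl.model_config['id']}:{digest}"
    try:
        cached = cache.get(key)
    except Exception as e:  # Redis down: fall through to direct synthesis
        logging.warning("TTS cache lookup failed: %s", e)
        cached = None
    if cached:
        return cached
    try:
        buf = b"".join(c for c in tts_mdl.tts(text) if isinstance(c, (bytes, bytearray)))
    except Exception as e:  # provider failure keeps the None contract
        logging.error("TTS failed: %s", e)
        return None
    if not buf:
        return None
    value = binascii.hexlify(buf).decode("utf-8")
    if ttl > 0:
        try:
            cache.set(key, value, exp=ttl)
        except Exception as e:  # store failure is logged, audio still returned
            logging.warning("TTS cache store failed: %s", e)
    return value
```

Each test-plan bullet then maps to a couple of assertions: call twice and compare `calls`, swap models and count stored keys, pass `ttl=0`, or swap in `BrokenCache`.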
2026-05-19 14:20:40 +08:00


```python
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import binascii
import hashlib
import logging
import os
from typing import Any, Optional

from rag.utils.redis_conn import REDIS_CONN

_DEFAULT_TTL_SECONDS = 7 * 24 * 60 * 60
_KEY_PREFIX = "tts:cache:"


def _ttl_seconds() -> int:
    raw = os.environ.get("RAGFLOW_TTS_CACHE_TTL_SECONDS")
    if not raw:
        return _DEFAULT_TTL_SECONDS
    try:
        v = int(raw)
        return v if v > 0 else 0
    except ValueError:
        logging.warning("Invalid RAGFLOW_TTS_CACHE_TTL_SECONDS=%r, using default", raw)
        return _DEFAULT_TTL_SECONDS


def _model_id(tts_mdl: Any) -> Optional[str]:
    cfg = getattr(tts_mdl, "model_config", None)
    if isinstance(cfg, dict):
        mid = cfg.get("id")
        if mid is not None:
            return str(mid)
        name = cfg.get("llm_name") or cfg.get("model_name")
        if name:
            return str(name)
    return None


def _build_key(tts_mdl: Any, text: str) -> Optional[str]:
    mid = _model_id(tts_mdl)
    if not mid:
        return None
    digest = hashlib.sha256(text.encode("utf-8", "ignore")).hexdigest()
    return f"{_KEY_PREFIX}{mid}:{digest}"


def _to_hex_string(value: Any) -> Optional[str]:
    if value is None:
        return None
    if isinstance(value, bytes):
        try:
            return value.decode("utf-8")
        except Exception:
            return None
    if isinstance(value, str):
        return value
    return None


def synthesize_with_cache(tts_mdl: Any, cleaned_text: str) -> Optional[str]:
    """
    Synthesize ``cleaned_text`` through ``tts_mdl`` and return a hex-encoded
    audio blob, reusing a Redis-cached result when available.

    The cache key is derived from the TTS model identifier and a SHA-256 of the
    text, so different models keep separate caches and the same text on the
    same model resolves to the same key regardless of call site. Returns
    ``None`` on synthesis failure; callers should treat that as a no-op the
    same way they do today.
    """
    if not tts_mdl or not cleaned_text:
        return None

    key = _build_key(tts_mdl, cleaned_text)
    if key:
        try:
            cached = REDIS_CONN.get(key)
        except Exception as e:
            logging.warning("TTS cache lookup failed: %s", e)
            cached = None
        hex_cached = _to_hex_string(cached)
        if hex_cached:
            return hex_cached

    buf = b""
    try:
        for chunk in tts_mdl.tts(cleaned_text):
            if isinstance(chunk, (bytes, bytearray)):
                buf += bytes(chunk)
    except Exception as e:
        logging.error("TTS failed: %s (text length=%d)", e, len(cleaned_text))
        return None

    if not buf:
        return None

    hex_value = binascii.hexlify(buf).decode("utf-8")
    ttl = _ttl_seconds()
    if key and ttl > 0:
        try:
            REDIS_CONN.set(key, hex_value, exp=ttl)
        except Exception as e:
            logging.warning("TTS cache store failed: %s", e)
    return hex_value
```
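Caching only happens when `_model_id` can recover an identifier from the model object; otherwise `_build_key` returns `None` and Redis is skipped. The sketch below runs the same resolution order (`id`, then `llm_name`, then `model_name`) against duck-typed stubs. `resolve_model_id` and the `SimpleNamespace` stubs are illustrative: they mirror the function above rather than import it, since importing `rag.utils.tts_cache` also imports `REDIS_CONN` and so needs a RAGFlow environment.

```python
from types import SimpleNamespace
from typing import Any, Optional


def resolve_model_id(tts_mdl: Any) -> Optional[str]:
    """Mirror of _model_id: prefer config id, then llm_name, then model_name."""
    cfg = getattr(tts_mdl, "model_config", None)
    if isinstance(cfg, dict):
        mid = cfg.get("id")
        if mid is not None:
            return str(mid)
        name = cfg.get("llm_name") or cfg.get("model_name")
        if name:
            return str(name)
    return None  # no identifier: the helper synthesizes without caching


with_id = SimpleNamespace(model_config={"id": 42, "llm_name": "tts-1"})
name_only = SimpleNamespace(model_config={"llm_name": "tts-1"})
no_config = SimpleNamespace()

print(resolve_model_id(with_id), resolve_model_id(name_only), resolve_model_id(no_config))
# prints: 42 tts-1 None
```

A model whose config carries none of these fields is synthesized on every call, exactly as before this change.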