mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-21 08:37:05 +08:00
## What problem does this PR solve?

Closes #12017.

For a given `(model, text)` pair, TTS output is effectively deterministic: re-running the same text through the same TTS model yields the same (or interchangeable) audio. Yet `Canvas.tts` and `dialog_service.tts` re-synthesized it on every request. That is slow and wastes provider quota whenever the same assistant response is replayed, shared across users, or repeated within a session.

### Change

New helper `rag/utils/tts_cache.py` with `synthesize_with_cache(tts_mdl, cleaned_text)`:

- **Key:** `tts:cache:{model_id}:{sha256(text)}`. Each model gets a separate namespace, and identical cleaned text reuses a single entry across both call sites.
- **Value:** the hex-encoded audio blob that both call sites already returned. No format change for downstream consumers.
- **TTL:** 7 days by default, configurable via `RAGFLOW_TTS_CACHE_TTL_SECONDS`.
- **Failure modes:** a Redis hiccup falls back to direct synthesis; a failed synthesis still returns `None` (existing contract preserved).

[`Canvas.tts`](https://github.com/infiniflow/ragflow/blob/main/agent/canvas.py#L683-L724) and [`dialog_service.tts`](https://github.com/infiniflow/ragflow/blob/main/api/db/services/dialog_service.py#L1367-L1380) now route through the helper. The per-file bytes-accumulation/hex-encode loop has been removed in favor of one shared implementation.

## Type of change

- [x] New Feature (non-breaking change which adds functionality)

## Test plan

- [ ] **Cache hit, chat path:** Configure a dialog with TTS enabled and ask the same question twice with `stream=false`. Verify that:
  - the second response returns the same `audio_binary`;
  - the second invocation does not hit the TTS provider (e.g., check provider-side logs / usage counters, and confirm no `LLMBundle.tts can't update token usage` log line on the second run).
- [ ] **Cache hit, agent path:** Repeat the same exercise via a Conversational Agent that includes a Message component playing back the answer.
- [ ] **Cache isolation per model:** Switch the tenant's `tts_id` between two models and run the same text against each. Confirm that the second model's first synthesis still happens (no cross-model hits).
- [ ] **TTL override:** Set `RAGFLOW_TTS_CACHE_TTL_SECONDS=120` and confirm the entry expires after 2 minutes.
- [ ] **Redis unavailable:** Stop Redis (or break the connection). Verify the TTS endpoint still works: synthesis should fall back to direct calls, with a `TTS cache lookup failed` / `TTS cache store failed` warning logged.
- [ ] **Failure path:** Configure a TTS model with an invalid API key and ensure the response still returns successfully with `audio_binary=None` (no regression vs. current behavior).
121 lines
3.5 KiB
Python
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
|
|
import binascii
|
|
import hashlib
|
|
import logging
|
|
import os
|
|
from typing import Any, Optional
|
|
|
|
from rag.utils.redis_conn import REDIS_CONN
|
|
|
|
_DEFAULT_TTL_SECONDS = 7 * 24 * 60 * 60
|
|
_KEY_PREFIX = "tts:cache:"
|
|
|
|
|
|
def _ttl_seconds() -> int:
|
|
raw = os.environ.get("RAGFLOW_TTS_CACHE_TTL_SECONDS")
|
|
if not raw:
|
|
return _DEFAULT_TTL_SECONDS
|
|
try:
|
|
v = int(raw)
|
|
return v if v > 0 else 0
|
|
except ValueError:
|
|
logging.warning("Invalid RAGFLOW_TTS_CACHE_TTL_SECONDS=%r, using default", raw)
|
|
return _DEFAULT_TTL_SECONDS
|
|
|
|
|
|
def _model_id(tts_mdl: Any) -> Optional[str]:
|
|
cfg = getattr(tts_mdl, "model_config", None)
|
|
if isinstance(cfg, dict):
|
|
mid = cfg.get("id")
|
|
if mid is not None:
|
|
return str(mid)
|
|
name = cfg.get("llm_name") or cfg.get("model_name")
|
|
if name:
|
|
return str(name)
|
|
return None
|
|
|
|
|
|
def _build_key(tts_mdl: Any, text: str) -> Optional[str]:
    """Return the Redis key caching ``text`` spoken by ``tts_mdl``.

    The key has the shape ``<_KEY_PREFIX><model id>:<sha256 hex of text>``.
    Because the digest has a fixed length, the key is unambiguous even if the
    model id itself contains ``":"``.

    Returns ``None`` when no model identifier can be derived, which disables
    caching for this call.
    """
    mid = _model_id(tts_mdl)
    if not mid:
        return None
    # "surrogatepass" keeps the str -> bytes mapping injective. With "ignore",
    # strings differing only by lone surrogates (which can survive e.g. JSON
    # decoding of a broken "\uD83D" escape) would hash to the same key and be
    # served the same cached audio. For well-formed text the bytes, and hence
    # the keys already in Redis, are unchanged.
    digest = hashlib.sha256(text.encode("utf-8", "surrogatepass")).hexdigest()
    return f"{_KEY_PREFIX}{mid}:{digest}"
|
|
|
|
|
|
def _to_hex_string(value: Any) -> Optional[str]:
|
|
if value is None:
|
|
return None
|
|
if isinstance(value, bytes):
|
|
try:
|
|
return value.decode("utf-8")
|
|
except Exception:
|
|
return None
|
|
if isinstance(value, str):
|
|
return value
|
|
return None
|
|
|
|
|
|
def synthesize_with_cache(tts_mdl: Any, cleaned_text: str) -> Optional[str]:
    """
    Synthesize ``cleaned_text`` through ``tts_mdl`` and return a hex-encoded
    audio blob, reusing a Redis-cached result when available.

    The cache key is derived from the TTS model identifier and a SHA-256 of the
    text, so different models keep separate caches and the same text on the
    same model resolves to the same key regardless of call site.

    Args:
        tts_mdl: Model object whose ``tts(text)`` yields audio chunks. Chunks
            that are not ``bytes``/``bytearray`` are ignored. Its
            ``model_config`` provides the cache namespace; without one the
            call is not cached.
        cleaned_text: Text to synthesize.

    Returns:
        The lowercase hex encoding of the concatenated audio bytes, or
        ``None`` when either argument is empty, synthesis raises, or no audio
        was produced. Callers should treat ``None`` as a no-op, as before.

    Redis errors on lookup or store are logged as warnings and never
    propagate; the call then behaves as an uncached synthesis. A TTL <= 0
    disables the cache entirely: neither a lookup nor a store is performed.
    """
    if not tts_mdl or not cleaned_text:
        return None

    ttl = _ttl_seconds()
    # With caching disabled there is nothing worth reading back, so skip the
    # Redis round-trip (which would also stall when Redis is unreachable).
    key = _build_key(tts_mdl, cleaned_text) if ttl > 0 else None

    if key:
        try:
            cached = REDIS_CONN.get(key)
        except Exception as e:
            logging.warning("TTS cache lookup failed: %s", e)
            cached = None
        hex_cached = _to_hex_string(cached)
        if hex_cached:
            return hex_cached

    # Collect chunks and join once; repeated ``bytes +=`` is quadratic in the
    # number of chunks a streaming TTS backend yields.
    chunks = []
    try:
        for chunk in tts_mdl.tts(cleaned_text):
            if isinstance(chunk, (bytes, bytearray)):
                # bytes() snapshots bytearrays in case the producer reuses its
                # buffer; for real bytes objects it returns the same object.
                chunks.append(bytes(chunk))
    except Exception as e:
        logging.error("TTS failed: %s (text length=%d)", e, len(cleaned_text))
        return None

    audio = b"".join(chunks)
    if not audio:
        return None

    hex_value = binascii.hexlify(audio).decode("utf-8")

    if key:
        try:
            REDIS_CONN.set(key, hex_value, exp=ttl)
        except Exception as e:
            logging.warning("TTS cache store failed: %s", e)

    return hex_value
|