Refactor: Enhance delta streaming in chat functions for improved reasoning and content handling (#12453)

### What problem does this PR solve?

change:
Enhance delta streaming in chat functions for improved reasoning and
content handling

### Type of change


- [x] Refactoring
This commit is contained in:
buua436
2026-01-08 13:34:16 +08:00
committed by GitHub
parent f4e2783eb4
commit 1996aa0dac
5 changed files with 325 additions and 123 deletions

View File

@ -37,9 +37,11 @@ class DeepResearcher:
self._kg_retrieve = kg_retrieve
def _remove_tags(text: str, start_tag: str, end_tag: str) -> str:
"""General Tag Removal Method"""
pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
return re.sub(pattern, "", text)
"""Remove tags but keep the content between them."""
if not text:
return text
text = re.sub(re.escape(start_tag), "", text)
return re.sub(re.escape(end_tag), "", text)
@staticmethod
def _remove_query_tags(text: str) -> str:
@ -52,21 +54,29 @@ class DeepResearcher:
return DeepResearcher._remove_tags(text, BEGIN_SEARCH_RESULT, END_SEARCH_RESULT)
async def _generate_reasoning(self, msg_history):
"""Generate reasoning steps"""
query_think = ""
"""Generate reasoning steps (delta output)"""
raw_answer = ""
cleaned_answer = ""
if msg_history[-1]["role"] != "user":
msg_history.append({"role": "user", "content": "Continues reasoning with the new information.\n"})
else:
msg_history[-1]["content"] += "\n\nContinues reasoning with the new information.\n"
async for ans in self.chat_mdl.async_chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}):
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
if not ans:
async for delta in self.chat_mdl.async_chat_streamly_delta(REASON_PROMPT, msg_history, {"temperature": 0.7}):
if not delta:
continue
query_think = ans
yield query_think
query_think = ""
yield query_think
raw_answer += delta
cleaned_full = re.sub(r"^.*</think>", "", raw_answer, flags=re.DOTALL)
if not cleaned_full:
continue
if cleaned_full.startswith(cleaned_answer):
delta_clean = cleaned_full[len(cleaned_answer):]
else:
delta_clean = cleaned_full
if not delta_clean:
continue
cleaned_answer = cleaned_full
yield delta_clean
def _extract_search_queries(self, query_think, question, step_index):
"""Extract search queries from thinking"""
@ -93,7 +103,7 @@ class DeepResearcher:
else:
if truncated_prev_reasoning[-len('\n\n...\n\n'):] != '\n\n...\n\n':
truncated_prev_reasoning += '...\n\n'
return truncated_prev_reasoning.strip('\n')
def _retrieve_information(self, search_query):
@ -138,16 +148,17 @@ class DeepResearcher:
for c in kbinfos["chunks"]:
if c["chunk_id"] not in cids:
chunk_info["chunks"].append(c)
dids = [d["doc_id"] for d in chunk_info["doc_aggs"]]
for d in kbinfos["doc_aggs"]:
if d["doc_id"] not in dids:
chunk_info["doc_aggs"].append(d)
async def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos):
"""Extract and summarize relevant information"""
summary_think = ""
async for ans in self.chat_mdl.async_chat_streamly(
"""Extract and summarize relevant information (delta output)"""
raw_answer = ""
cleaned_answer = ""
async for delta in self.chat_mdl.async_chat_streamly_delta(
RELEVANT_EXTRACTION_PROMPT.format(
prev_reasoning=truncated_prev_reasoning,
search_query=search_query,
@ -156,39 +167,92 @@ class DeepResearcher:
[{"role": "user",
"content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}],
{"temperature": 0.7}):
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
if not ans:
if not delta:
continue
summary_think = ans
yield summary_think
summary_think = ""
yield summary_think
raw_answer += delta
cleaned_full = re.sub(r"^.*</think>", "", raw_answer, flags=re.DOTALL)
if not cleaned_full:
continue
if cleaned_full.startswith(cleaned_answer):
delta_clean = cleaned_full[len(cleaned_answer):]
else:
delta_clean = cleaned_full
if not delta_clean:
continue
cleaned_answer = cleaned_full
yield delta_clean
async def thinking(self, chunk_info: dict, question: str):
executed_search_queries = []
msg_history = [{"role": "user", "content": f'Question:\"{question}\"\n'}]
all_reasoning_steps = []
think = "<think>"
last_idx = 0
endswith_think = False
last_full = ""
def emit_delta(full_text: str):
nonlocal last_idx, endswith_think, last_full
if full_text == last_full:
return None
last_full = full_text
delta_ans = full_text[last_idx:]
if delta_ans.find("<think>") == 0:
last_idx += len("<think>")
delta = "<think>"
elif delta_ans.find("<think>") > 0:
delta = full_text[last_idx:last_idx + delta_ans.find("<think>")]
last_idx += delta_ans.find("<think>")
elif delta_ans.endswith("</think>"):
endswith_think = True
delta = re.sub(r"(<think>|</think>)", "", delta_ans)
elif endswith_think:
endswith_think = False
delta = "</think>"
else:
last_idx = len(full_text)
if full_text.endswith("</think>"):
last_idx -= len("</think>")
delta = re.sub(r"(<think>|</think>)", "", delta_ans)
if not delta:
return None
if delta == "<think>":
return {"answer": "", "reference": {}, "audio_binary": None, "final": False, "start_to_think": True}
if delta == "</think>":
return {"answer": "", "reference": {}, "audio_binary": None, "final": False, "end_to_think": True}
return {"answer": delta, "reference": {}, "audio_binary": None, "final": False}
def flush_think_close():
nonlocal endswith_think
if endswith_think:
endswith_think = False
return {"answer": "", "reference": {}, "audio_binary": None, "final": False, "end_to_think": True}
return None
for step_index in range(MAX_SEARCH_LIMIT + 1):
# Check if the maximum search limit has been reached
if step_index == MAX_SEARCH_LIMIT - 1:
summary_think = f"\n{BEGIN_SEARCH_RESULT}\nThe maximum search limit is exceeded. You are not allowed to search.\n{END_SEARCH_RESULT}\n"
yield {"answer": think + summary_think + "</think>", "reference": {}, "audio_binary": None}
payload = emit_delta(think + summary_think)
if payload:
yield payload
all_reasoning_steps.append(summary_think)
msg_history.append({"role": "assistant", "content": summary_think})
break
# Step 1: Generate reasoning
query_think = ""
async for ans in self._generate_reasoning(msg_history):
query_think = ans
yield {"answer": think + self._remove_query_tags(query_think) + "</think>", "reference": {}, "audio_binary": None}
async for delta in self._generate_reasoning(msg_history):
query_think += delta
payload = emit_delta(think + self._remove_query_tags(query_think))
if payload:
yield payload
think += self._remove_query_tags(query_think)
all_reasoning_steps.append(query_think)
# Step 2: Extract search queries
queries = self._extract_search_queries(query_think, question, step_index)
if not queries and step_index > 0:
@ -197,42 +261,51 @@ class DeepResearcher:
# Process each search query
for search_query in queries:
logging.info(f"[THINK]Query: {step_index}. {search_query}")
msg_history.append({"role": "assistant", "content": search_query})
think += f"\n\n> {step_index + 1}. {search_query}\n\n"
yield {"answer": think + "</think>", "reference": {}, "audio_binary": None}
payload = emit_delta(think)
if payload:
yield payload
# Check if the query has already been executed
if search_query in executed_search_queries:
summary_think = f"\n{BEGIN_SEARCH_RESULT}\nYou have searched this query. Please refer to previous results.\n{END_SEARCH_RESULT}\n"
yield {"answer": think + summary_think + "</think>", "reference": {}, "audio_binary": None}
payload = emit_delta(think + summary_think)
if payload:
yield payload
all_reasoning_steps.append(summary_think)
msg_history.append({"role": "user", "content": summary_think})
think += summary_think
continue
executed_search_queries.append(search_query)
# Step 3: Truncate previous reasoning steps
truncated_prev_reasoning = self._truncate_previous_reasoning(all_reasoning_steps)
# Step 4: Retrieve information
kbinfos = self._retrieve_information(search_query)
# Step 5: Update chunk information
self._update_chunk_info(chunk_info, kbinfos)
# Step 6: Extract relevant information
think += "\n\n"
summary_think = ""
async for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos):
summary_think = ans
yield {"answer": think + self._remove_result_tags(summary_think) + "</think>", "reference": {}, "audio_binary": None}
async for delta in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos):
summary_think += delta
payload = emit_delta(think + self._remove_result_tags(summary_think))
if payload:
yield payload
all_reasoning_steps.append(summary_think)
msg_history.append(
{"role": "user", "content": f"\n\n{BEGIN_SEARCH_RESULT}{summary_think}{END_SEARCH_RESULT}\n\n"})
think += self._remove_result_tags(summary_think)
logging.info(f"[THINK]Summary: {step_index}. {summary_think}")
yield think + "</think>"
final_payload = emit_delta(think + "</think>")
if final_payload:
yield final_payload
close_payload = flush_think_close()
if close_payload:
yield close_payload

View File

@ -304,9 +304,12 @@ async def chat_completion_openai_like(tenant_id, chat_id):
# The choices field on the last chunk will always be an empty array [].
async def streamed_response_generator(chat_id, dia, msg):
token_used = 0
answer_cache = ""
reasoning_cache = ""
last_ans = {}
full_content = ""
full_reasoning = ""
final_answer = None
final_reference = None
in_think = False
response = {
"id": f"chatcmpl-{chat_id}",
"choices": [
@ -336,47 +339,30 @@ async def chat_completion_openai_like(tenant_id, chat_id):
chat_kwargs["doc_ids"] = doc_ids_str
async for ans in async_chat(dia, msg, True, **chat_kwargs):
last_ans = ans
answer = ans["answer"]
reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
if reasoning_match:
reasoning_part = reasoning_match.group(1)
content_part = answer[reasoning_match.end() :]
else:
reasoning_part = ""
content_part = answer
reasoning_incremental = ""
if reasoning_part:
if reasoning_part.startswith(reasoning_cache):
reasoning_incremental = reasoning_part.replace(reasoning_cache, "", 1)
else:
reasoning_incremental = reasoning_part
reasoning_cache = reasoning_part
content_incremental = ""
if content_part:
if content_part.startswith(answer_cache):
content_incremental = content_part.replace(answer_cache, "", 1)
else:
content_incremental = content_part
answer_cache = content_part
token_used += len(reasoning_incremental) + len(content_incremental)
if not any([reasoning_incremental, content_incremental]):
if ans.get("final"):
if ans.get("answer"):
full_content = ans["answer"]
final_answer = ans.get("answer") or full_content
final_reference = ans.get("reference", {})
continue
if reasoning_incremental:
response["choices"][0]["delta"]["reasoning_content"] = reasoning_incremental
else:
response["choices"][0]["delta"]["reasoning_content"] = None
if content_incremental:
response["choices"][0]["delta"]["content"] = content_incremental
else:
if ans.get("start_to_think"):
in_think = True
continue
if ans.get("end_to_think"):
in_think = False
continue
delta = ans.get("answer") or ""
if not delta:
continue
token_used += len(delta)
if in_think:
full_reasoning += delta
response["choices"][0]["delta"]["reasoning_content"] = delta
response["choices"][0]["delta"]["content"] = None
else:
full_content += delta
response["choices"][0]["delta"]["content"] = delta
response["choices"][0]["delta"]["reasoning_content"] = None
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
except Exception as e:
response["choices"][0]["delta"]["content"] = "**ERROR**: " + str(e)
@ -388,8 +374,9 @@ async def chat_completion_openai_like(tenant_id, chat_id):
response["choices"][0]["finish_reason"] = "stop"
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
if need_reference:
response["choices"][0]["delta"]["reference"] = chunks_format(last_ans.get("reference", []))
response["choices"][0]["delta"]["final_content"] = last_ans.get("answer", "")
reference_payload = final_reference if final_reference is not None else last_ans.get("reference", [])
response["choices"][0]["delta"]["reference"] = chunks_format(reference_payload)
response["choices"][0]["delta"]["final_content"] = final_answer if final_answer is not None else full_content
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
yield "data:[DONE]\n\n"

View File

@ -69,6 +69,7 @@ def structure_answer(conv, ans, message_id, session_id):
if not isinstance(reference, dict):
reference = {}
ans["reference"] = {}
is_final = ans.get("final", True)
chunk_list = chunks_format(reference)
@ -81,12 +82,29 @@ def structure_answer(conv, ans, message_id, session_id):
if not conv.message:
conv.message = []
content = ans["answer"]
if ans.get("start_to_think"):
content = "<think>"
elif ans.get("end_to_think"):
content = "</think>"
if not conv.message or conv.message[-1].get("role", "") != "assistant":
conv.message.append({"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id})
conv.message.append({"role": "assistant", "content": content, "created_at": time.time(), "id": message_id})
else:
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id}
if is_final:
if ans.get("answer"):
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id}
else:
conv.message[-1]["created_at"] = time.time()
conv.message[-1]["id"] = message_id
else:
conv.message[-1]["content"] = (conv.message[-1].get("content") or "") + content
conv.message[-1]["created_at"] = time.time()
conv.message[-1]["id"] = message_id
if conv.reference:
conv.reference[-1] = reference
should_update_reference = is_final or bool(reference.get("chunks")) or bool(reference.get("doc_aggs"))
if should_update_reference:
conv.reference[-1] = reference
return ans
async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):

View File

@ -196,19 +196,13 @@ async def async_chat_solo(dialog, messages, stream=True):
if attachments and msg:
msg[-1]["content"] += attachments
if stream:
last_ans = ""
delta_ans = ""
answer = ""
async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
answer = ans
delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 16:
stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting)
async for kind, value, state in _stream_with_think_delta(stream_iter):
if kind == "marker":
flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
yield {"answer": "", "reference": {}, "audio_binary": None, "prompt": "", "created_at": time.time(), "final": False, **flags}
continue
last_ans = answer
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
delta_ans = ""
if delta_ans:
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "prompt": "", "created_at": time.time(), "final": False}
else:
answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
user_content = msg[-1].get("content", "[content not available]")
@ -434,8 +428,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
if not knowledges and prompt_config.get("empty_response"):
empty_res = prompt_config["empty_response"]
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
"audio_binary": tts(tts_mdl, empty_res)}
yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
"audio_binary": tts(tts_mdl, empty_res), "final": True}
return
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
@ -538,21 +531,22 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
)
if stream:
last_ans = ""
answer = ""
async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
if thought:
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
answer = ans
delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 16:
stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf)
last_state = None
async for kind, value, state in _stream_with_think_delta(stream_iter):
last_state = state
if kind == "marker":
flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, **flags}
continue
last_ans = answer
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
delta_ans = answer[len(last_ans):]
if delta_ans:
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
yield decorate_answer(thought + answer)
yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "final": False}
full_answer = last_state.full_text if last_state else ""
if full_answer:
final = decorate_answer(thought + full_answer)
final["final"] = True
final["audio_binary"] = None
final["answer"] = ""
yield final
else:
answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
user_content = msg[-1].get("content", "[content not available]")
@ -733,6 +727,84 @@ def tts(tts_mdl, text):
return None
return binascii.hexlify(bin).decode("utf-8")
class _ThinkStreamState:
def __init__(self) -> None:
self.full_text = ""
self.last_idx = 0
self.endswith_think = False
self.last_full = ""
self.last_model_full = ""
self.in_think = False
self.buffer = ""
def _next_think_delta(state: _ThinkStreamState) -> str:
full_text = state.full_text
if full_text == state.last_full:
return ""
state.last_full = full_text
delta_ans = full_text[state.last_idx:]
if delta_ans.find("<think>") == 0:
state.last_idx += len("<think>")
return "<think>"
if delta_ans.find("<think>") > 0:
delta_text = full_text[state.last_idx:state.last_idx + delta_ans.find("<think>")]
state.last_idx += delta_ans.find("<think>")
return delta_text
if delta_ans.endswith("</think>"):
state.endswith_think = True
elif state.endswith_think:
state.endswith_think = False
return "</think>"
state.last_idx = len(full_text)
if full_text.endswith("</think>"):
state.last_idx -= len("</think>")
return re.sub(r"(<think>|</think>)", "", delta_ans)
async def _stream_with_think_delta(stream_iter, min_tokens: int = 16):
state = _ThinkStreamState()
async for chunk in stream_iter:
if not chunk:
continue
if chunk.startswith(state.last_model_full):
new_part = chunk[len(state.last_model_full):]
state.last_model_full = chunk
else:
new_part = chunk
state.last_model_full += chunk
if not new_part:
continue
state.full_text += new_part
delta = _next_think_delta(state)
if not delta:
continue
if delta in ("<think>", "</think>"):
if delta == "<think>" and state.in_think:
continue
if delta == "</think>" and not state.in_think:
continue
if state.buffer:
yield ("text", state.buffer, state)
state.buffer = ""
state.in_think = delta == "<think>"
yield ("marker", delta, state)
continue
state.buffer += delta
if num_tokens_from_string(state.buffer) < min_tokens:
continue
yield ("text", state.buffer, state)
state.buffer = ""
if state.buffer:
yield ("text", state.buffer, state)
state.buffer = ""
if state.endswith_think:
yield ("marker", "</think>", state)
async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
doc_ids = search_config.get("doc_ids", [])
rerank_mdl = None
@ -798,11 +870,20 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
refs["chunks"] = chunks_format(refs)
return {"answer": answer, "reference": refs}
answer = ""
async for ans in chat_mdl.async_chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
answer = ans
yield {"answer": answer, "reference": {}}
yield decorate_answer(answer)
stream_iter = chat_mdl.async_chat_streamly_delta(sys_prompt, msg, {"temperature": 0.1})
last_state = None
async for kind, value, state in _stream_with_think_delta(stream_iter):
last_state = state
if kind == "marker":
flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
yield {"answer": "", "reference": {}, "final": False, **flags}
continue
yield {"answer": value, "reference": {}, "final": False}
full_answer = last_state.full_text if last_state else ""
final = decorate_answer(full_answer)
final["final"] = True
final["answer"] = ""
yield final
async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):

View File

@ -441,3 +441,46 @@ class LLMBundle(LLM4Tenant):
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
generation.end()
return
async def async_chat_streamly_delta(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
total_tokens = 0
ans = ""
if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"):
stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
elif hasattr(self.mdl, "async_chat_streamly"):
stream_fn = getattr(self.mdl, "async_chat_streamly", None)
else:
raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
generation = None
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
if stream_fn:
chat_partial = partial(stream_fn, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
try:
async for txt in chat_partial(**use_kwargs):
if isinstance(txt, int):
total_tokens = txt
break
if txt.endswith("</think>"):
ans = ans[: -len("</think>")]
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
ans += txt
yield txt
except Exception as e:
if generation:
generation.update(output={"error": str(e)})
generation.end()
raise
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
if generation:
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
generation.end()
return