mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-19 11:45:10 +08:00
Refactor: Enhance delta streaming in chat functions for improved reasoning and content handling (#12453)
### What problem does this PR solve? change: Enhance delta streaming in chat functions for improved reasoning and content handling ### Type of change - [x] Refactoring
This commit is contained in:
@ -37,9 +37,11 @@ class DeepResearcher:
|
||||
self._kg_retrieve = kg_retrieve
|
||||
|
||||
def _remove_tags(text: str, start_tag: str, end_tag: str) -> str:
|
||||
"""General Tag Removal Method"""
|
||||
pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
|
||||
return re.sub(pattern, "", text)
|
||||
"""Remove tags but keep the content between them."""
|
||||
if not text:
|
||||
return text
|
||||
text = re.sub(re.escape(start_tag), "", text)
|
||||
return re.sub(re.escape(end_tag), "", text)
|
||||
|
||||
@staticmethod
|
||||
def _remove_query_tags(text: str) -> str:
|
||||
@ -52,21 +54,29 @@ class DeepResearcher:
|
||||
return DeepResearcher._remove_tags(text, BEGIN_SEARCH_RESULT, END_SEARCH_RESULT)
|
||||
|
||||
async def _generate_reasoning(self, msg_history):
|
||||
"""Generate reasoning steps"""
|
||||
query_think = ""
|
||||
"""Generate reasoning steps (delta output)"""
|
||||
raw_answer = ""
|
||||
cleaned_answer = ""
|
||||
if msg_history[-1]["role"] != "user":
|
||||
msg_history.append({"role": "user", "content": "Continues reasoning with the new information.\n"})
|
||||
else:
|
||||
msg_history[-1]["content"] += "\n\nContinues reasoning with the new information.\n"
|
||||
|
||||
async for ans in self.chat_mdl.async_chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}):
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
if not ans:
|
||||
|
||||
async for delta in self.chat_mdl.async_chat_streamly_delta(REASON_PROMPT, msg_history, {"temperature": 0.7}):
|
||||
if not delta:
|
||||
continue
|
||||
query_think = ans
|
||||
yield query_think
|
||||
query_think = ""
|
||||
yield query_think
|
||||
raw_answer += delta
|
||||
cleaned_full = re.sub(r"^.*</think>", "", raw_answer, flags=re.DOTALL)
|
||||
if not cleaned_full:
|
||||
continue
|
||||
if cleaned_full.startswith(cleaned_answer):
|
||||
delta_clean = cleaned_full[len(cleaned_answer):]
|
||||
else:
|
||||
delta_clean = cleaned_full
|
||||
if not delta_clean:
|
||||
continue
|
||||
cleaned_answer = cleaned_full
|
||||
yield delta_clean
|
||||
|
||||
def _extract_search_queries(self, query_think, question, step_index):
|
||||
"""Extract search queries from thinking"""
|
||||
@ -93,7 +103,7 @@ class DeepResearcher:
|
||||
else:
|
||||
if truncated_prev_reasoning[-len('\n\n...\n\n'):] != '\n\n...\n\n':
|
||||
truncated_prev_reasoning += '...\n\n'
|
||||
|
||||
|
||||
return truncated_prev_reasoning.strip('\n')
|
||||
|
||||
def _retrieve_information(self, search_query):
|
||||
@ -138,16 +148,17 @@ class DeepResearcher:
|
||||
for c in kbinfos["chunks"]:
|
||||
if c["chunk_id"] not in cids:
|
||||
chunk_info["chunks"].append(c)
|
||||
|
||||
|
||||
dids = [d["doc_id"] for d in chunk_info["doc_aggs"]]
|
||||
for d in kbinfos["doc_aggs"]:
|
||||
if d["doc_id"] not in dids:
|
||||
chunk_info["doc_aggs"].append(d)
|
||||
|
||||
async def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos):
|
||||
"""Extract and summarize relevant information"""
|
||||
summary_think = ""
|
||||
async for ans in self.chat_mdl.async_chat_streamly(
|
||||
"""Extract and summarize relevant information (delta output)"""
|
||||
raw_answer = ""
|
||||
cleaned_answer = ""
|
||||
async for delta in self.chat_mdl.async_chat_streamly_delta(
|
||||
RELEVANT_EXTRACTION_PROMPT.format(
|
||||
prev_reasoning=truncated_prev_reasoning,
|
||||
search_query=search_query,
|
||||
@ -156,39 +167,92 @@ class DeepResearcher:
|
||||
[{"role": "user",
|
||||
"content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}],
|
||||
{"temperature": 0.7}):
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
if not ans:
|
||||
if not delta:
|
||||
continue
|
||||
summary_think = ans
|
||||
yield summary_think
|
||||
summary_think = ""
|
||||
|
||||
yield summary_think
|
||||
raw_answer += delta
|
||||
cleaned_full = re.sub(r"^.*</think>", "", raw_answer, flags=re.DOTALL)
|
||||
if not cleaned_full:
|
||||
continue
|
||||
if cleaned_full.startswith(cleaned_answer):
|
||||
delta_clean = cleaned_full[len(cleaned_answer):]
|
||||
else:
|
||||
delta_clean = cleaned_full
|
||||
if not delta_clean:
|
||||
continue
|
||||
cleaned_answer = cleaned_full
|
||||
yield delta_clean
|
||||
|
||||
async def thinking(self, chunk_info: dict, question: str):
|
||||
executed_search_queries = []
|
||||
msg_history = [{"role": "user", "content": f'Question:\"{question}\"\n'}]
|
||||
all_reasoning_steps = []
|
||||
think = "<think>"
|
||||
|
||||
last_idx = 0
|
||||
endswith_think = False
|
||||
last_full = ""
|
||||
|
||||
def emit_delta(full_text: str):
|
||||
nonlocal last_idx, endswith_think, last_full
|
||||
if full_text == last_full:
|
||||
return None
|
||||
last_full = full_text
|
||||
delta_ans = full_text[last_idx:]
|
||||
|
||||
if delta_ans.find("<think>") == 0:
|
||||
last_idx += len("<think>")
|
||||
delta = "<think>"
|
||||
elif delta_ans.find("<think>") > 0:
|
||||
delta = full_text[last_idx:last_idx + delta_ans.find("<think>")]
|
||||
last_idx += delta_ans.find("<think>")
|
||||
elif delta_ans.endswith("</think>"):
|
||||
endswith_think = True
|
||||
delta = re.sub(r"(<think>|</think>)", "", delta_ans)
|
||||
elif endswith_think:
|
||||
endswith_think = False
|
||||
delta = "</think>"
|
||||
else:
|
||||
last_idx = len(full_text)
|
||||
if full_text.endswith("</think>"):
|
||||
last_idx -= len("</think>")
|
||||
delta = re.sub(r"(<think>|</think>)", "", delta_ans)
|
||||
|
||||
if not delta:
|
||||
return None
|
||||
if delta == "<think>":
|
||||
return {"answer": "", "reference": {}, "audio_binary": None, "final": False, "start_to_think": True}
|
||||
if delta == "</think>":
|
||||
return {"answer": "", "reference": {}, "audio_binary": None, "final": False, "end_to_think": True}
|
||||
return {"answer": delta, "reference": {}, "audio_binary": None, "final": False}
|
||||
|
||||
def flush_think_close():
|
||||
nonlocal endswith_think
|
||||
if endswith_think:
|
||||
endswith_think = False
|
||||
return {"answer": "", "reference": {}, "audio_binary": None, "final": False, "end_to_think": True}
|
||||
return None
|
||||
|
||||
for step_index in range(MAX_SEARCH_LIMIT + 1):
|
||||
# Check if the maximum search limit has been reached
|
||||
if step_index == MAX_SEARCH_LIMIT - 1:
|
||||
summary_think = f"\n{BEGIN_SEARCH_RESULT}\nThe maximum search limit is exceeded. You are not allowed to search.\n{END_SEARCH_RESULT}\n"
|
||||
yield {"answer": think + summary_think + "</think>", "reference": {}, "audio_binary": None}
|
||||
payload = emit_delta(think + summary_think)
|
||||
if payload:
|
||||
yield payload
|
||||
all_reasoning_steps.append(summary_think)
|
||||
msg_history.append({"role": "assistant", "content": summary_think})
|
||||
break
|
||||
|
||||
# Step 1: Generate reasoning
|
||||
query_think = ""
|
||||
async for ans in self._generate_reasoning(msg_history):
|
||||
query_think = ans
|
||||
yield {"answer": think + self._remove_query_tags(query_think) + "</think>", "reference": {}, "audio_binary": None}
|
||||
async for delta in self._generate_reasoning(msg_history):
|
||||
query_think += delta
|
||||
payload = emit_delta(think + self._remove_query_tags(query_think))
|
||||
if payload:
|
||||
yield payload
|
||||
|
||||
think += self._remove_query_tags(query_think)
|
||||
all_reasoning_steps.append(query_think)
|
||||
|
||||
|
||||
# Step 2: Extract search queries
|
||||
queries = self._extract_search_queries(query_think, question, step_index)
|
||||
if not queries and step_index > 0:
|
||||
@ -197,42 +261,51 @@ class DeepResearcher:
|
||||
|
||||
# Process each search query
|
||||
for search_query in queries:
|
||||
logging.info(f"[THINK]Query: {step_index}. {search_query}")
|
||||
msg_history.append({"role": "assistant", "content": search_query})
|
||||
think += f"\n\n> {step_index + 1}. {search_query}\n\n"
|
||||
yield {"answer": think + "</think>", "reference": {}, "audio_binary": None}
|
||||
payload = emit_delta(think)
|
||||
if payload:
|
||||
yield payload
|
||||
|
||||
# Check if the query has already been executed
|
||||
if search_query in executed_search_queries:
|
||||
summary_think = f"\n{BEGIN_SEARCH_RESULT}\nYou have searched this query. Please refer to previous results.\n{END_SEARCH_RESULT}\n"
|
||||
yield {"answer": think + summary_think + "</think>", "reference": {}, "audio_binary": None}
|
||||
payload = emit_delta(think + summary_think)
|
||||
if payload:
|
||||
yield payload
|
||||
all_reasoning_steps.append(summary_think)
|
||||
msg_history.append({"role": "user", "content": summary_think})
|
||||
think += summary_think
|
||||
continue
|
||||
|
||||
|
||||
executed_search_queries.append(search_query)
|
||||
|
||||
|
||||
# Step 3: Truncate previous reasoning steps
|
||||
truncated_prev_reasoning = self._truncate_previous_reasoning(all_reasoning_steps)
|
||||
|
||||
|
||||
# Step 4: Retrieve information
|
||||
kbinfos = self._retrieve_information(search_query)
|
||||
|
||||
|
||||
# Step 5: Update chunk information
|
||||
self._update_chunk_info(chunk_info, kbinfos)
|
||||
|
||||
|
||||
# Step 6: Extract relevant information
|
||||
think += "\n\n"
|
||||
summary_think = ""
|
||||
async for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos):
|
||||
summary_think = ans
|
||||
yield {"answer": think + self._remove_result_tags(summary_think) + "</think>", "reference": {}, "audio_binary": None}
|
||||
async for delta in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos):
|
||||
summary_think += delta
|
||||
payload = emit_delta(think + self._remove_result_tags(summary_think))
|
||||
if payload:
|
||||
yield payload
|
||||
|
||||
all_reasoning_steps.append(summary_think)
|
||||
msg_history.append(
|
||||
{"role": "user", "content": f"\n\n{BEGIN_SEARCH_RESULT}{summary_think}{END_SEARCH_RESULT}\n\n"})
|
||||
think += self._remove_result_tags(summary_think)
|
||||
logging.info(f"[THINK]Summary: {step_index}. {summary_think}")
|
||||
|
||||
yield think + "</think>"
|
||||
final_payload = emit_delta(think + "</think>")
|
||||
if final_payload:
|
||||
yield final_payload
|
||||
close_payload = flush_think_close()
|
||||
if close_payload:
|
||||
yield close_payload
|
||||
|
||||
@ -304,9 +304,12 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
# The choices field on the last chunk will always be an empty array [].
|
||||
async def streamed_response_generator(chat_id, dia, msg):
|
||||
token_used = 0
|
||||
answer_cache = ""
|
||||
reasoning_cache = ""
|
||||
last_ans = {}
|
||||
full_content = ""
|
||||
full_reasoning = ""
|
||||
final_answer = None
|
||||
final_reference = None
|
||||
in_think = False
|
||||
response = {
|
||||
"id": f"chatcmpl-{chat_id}",
|
||||
"choices": [
|
||||
@ -336,47 +339,30 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
chat_kwargs["doc_ids"] = doc_ids_str
|
||||
async for ans in async_chat(dia, msg, True, **chat_kwargs):
|
||||
last_ans = ans
|
||||
answer = ans["answer"]
|
||||
|
||||
reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
|
||||
if reasoning_match:
|
||||
reasoning_part = reasoning_match.group(1)
|
||||
content_part = answer[reasoning_match.end() :]
|
||||
else:
|
||||
reasoning_part = ""
|
||||
content_part = answer
|
||||
|
||||
reasoning_incremental = ""
|
||||
if reasoning_part:
|
||||
if reasoning_part.startswith(reasoning_cache):
|
||||
reasoning_incremental = reasoning_part.replace(reasoning_cache, "", 1)
|
||||
else:
|
||||
reasoning_incremental = reasoning_part
|
||||
reasoning_cache = reasoning_part
|
||||
|
||||
content_incremental = ""
|
||||
if content_part:
|
||||
if content_part.startswith(answer_cache):
|
||||
content_incremental = content_part.replace(answer_cache, "", 1)
|
||||
else:
|
||||
content_incremental = content_part
|
||||
answer_cache = content_part
|
||||
|
||||
token_used += len(reasoning_incremental) + len(content_incremental)
|
||||
|
||||
if not any([reasoning_incremental, content_incremental]):
|
||||
if ans.get("final"):
|
||||
if ans.get("answer"):
|
||||
full_content = ans["answer"]
|
||||
final_answer = ans.get("answer") or full_content
|
||||
final_reference = ans.get("reference", {})
|
||||
continue
|
||||
|
||||
if reasoning_incremental:
|
||||
response["choices"][0]["delta"]["reasoning_content"] = reasoning_incremental
|
||||
else:
|
||||
response["choices"][0]["delta"]["reasoning_content"] = None
|
||||
|
||||
if content_incremental:
|
||||
response["choices"][0]["delta"]["content"] = content_incremental
|
||||
else:
|
||||
if ans.get("start_to_think"):
|
||||
in_think = True
|
||||
continue
|
||||
if ans.get("end_to_think"):
|
||||
in_think = False
|
||||
continue
|
||||
delta = ans.get("answer") or ""
|
||||
if not delta:
|
||||
continue
|
||||
token_used += len(delta)
|
||||
if in_think:
|
||||
full_reasoning += delta
|
||||
response["choices"][0]["delta"]["reasoning_content"] = delta
|
||||
response["choices"][0]["delta"]["content"] = None
|
||||
|
||||
else:
|
||||
full_content += delta
|
||||
response["choices"][0]["delta"]["content"] = delta
|
||||
response["choices"][0]["delta"]["reasoning_content"] = None
|
||||
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
response["choices"][0]["delta"]["content"] = "**ERROR**: " + str(e)
|
||||
@ -388,8 +374,9 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
response["choices"][0]["finish_reason"] = "stop"
|
||||
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
|
||||
if need_reference:
|
||||
response["choices"][0]["delta"]["reference"] = chunks_format(last_ans.get("reference", []))
|
||||
response["choices"][0]["delta"]["final_content"] = last_ans.get("answer", "")
|
||||
reference_payload = final_reference if final_reference is not None else last_ans.get("reference", [])
|
||||
response["choices"][0]["delta"]["reference"] = chunks_format(reference_payload)
|
||||
response["choices"][0]["delta"]["final_content"] = final_answer if final_answer is not None else full_content
|
||||
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
|
||||
yield "data:[DONE]\n\n"
|
||||
|
||||
|
||||
@ -69,6 +69,7 @@ def structure_answer(conv, ans, message_id, session_id):
|
||||
if not isinstance(reference, dict):
|
||||
reference = {}
|
||||
ans["reference"] = {}
|
||||
is_final = ans.get("final", True)
|
||||
|
||||
chunk_list = chunks_format(reference)
|
||||
|
||||
@ -81,12 +82,29 @@ def structure_answer(conv, ans, message_id, session_id):
|
||||
|
||||
if not conv.message:
|
||||
conv.message = []
|
||||
content = ans["answer"]
|
||||
if ans.get("start_to_think"):
|
||||
content = "<think>"
|
||||
elif ans.get("end_to_think"):
|
||||
content = "</think>"
|
||||
|
||||
if not conv.message or conv.message[-1].get("role", "") != "assistant":
|
||||
conv.message.append({"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id})
|
||||
conv.message.append({"role": "assistant", "content": content, "created_at": time.time(), "id": message_id})
|
||||
else:
|
||||
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id}
|
||||
if is_final:
|
||||
if ans.get("answer"):
|
||||
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id}
|
||||
else:
|
||||
conv.message[-1]["created_at"] = time.time()
|
||||
conv.message[-1]["id"] = message_id
|
||||
else:
|
||||
conv.message[-1]["content"] = (conv.message[-1].get("content") or "") + content
|
||||
conv.message[-1]["created_at"] = time.time()
|
||||
conv.message[-1]["id"] = message_id
|
||||
if conv.reference:
|
||||
conv.reference[-1] = reference
|
||||
should_update_reference = is_final or bool(reference.get("chunks")) or bool(reference.get("doc_aggs"))
|
||||
if should_update_reference:
|
||||
conv.reference[-1] = reference
|
||||
return ans
|
||||
|
||||
async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
|
||||
|
||||
@ -196,19 +196,13 @@ async def async_chat_solo(dialog, messages, stream=True):
|
||||
if attachments and msg:
|
||||
msg[-1]["content"] += attachments
|
||||
if stream:
|
||||
last_ans = ""
|
||||
delta_ans = ""
|
||||
answer = ""
|
||||
async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
|
||||
answer = ans
|
||||
delta_ans = ans[len(last_ans):]
|
||||
if num_tokens_from_string(delta_ans) < 16:
|
||||
stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting)
|
||||
async for kind, value, state in _stream_with_think_delta(stream_iter):
|
||||
if kind == "marker":
|
||||
flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
|
||||
yield {"answer": "", "reference": {}, "audio_binary": None, "prompt": "", "created_at": time.time(), "final": False, **flags}
|
||||
continue
|
||||
last_ans = answer
|
||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
|
||||
delta_ans = ""
|
||||
if delta_ans:
|
||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
|
||||
yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "prompt": "", "created_at": time.time(), "final": False}
|
||||
else:
|
||||
answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
|
||||
user_content = msg[-1].get("content", "[content not available]")
|
||||
@ -434,8 +428,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
if not knowledges and prompt_config.get("empty_response"):
|
||||
empty_res = prompt_config["empty_response"]
|
||||
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
|
||||
"audio_binary": tts(tts_mdl, empty_res)}
|
||||
yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
||||
"audio_binary": tts(tts_mdl, empty_res), "final": True}
|
||||
return
|
||||
|
||||
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
|
||||
@ -538,21 +531,22 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
)
|
||||
|
||||
if stream:
|
||||
last_ans = ""
|
||||
answer = ""
|
||||
async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
|
||||
if thought:
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
answer = ans
|
||||
delta_ans = ans[len(last_ans):]
|
||||
if num_tokens_from_string(delta_ans) < 16:
|
||||
stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf)
|
||||
last_state = None
|
||||
async for kind, value, state in _stream_with_think_delta(stream_iter):
|
||||
last_state = state
|
||||
if kind == "marker":
|
||||
flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
|
||||
yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, **flags}
|
||||
continue
|
||||
last_ans = answer
|
||||
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
delta_ans = answer[len(last_ans):]
|
||||
if delta_ans:
|
||||
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
yield decorate_answer(thought + answer)
|
||||
yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "final": False}
|
||||
full_answer = last_state.full_text if last_state else ""
|
||||
if full_answer:
|
||||
final = decorate_answer(thought + full_answer)
|
||||
final["final"] = True
|
||||
final["audio_binary"] = None
|
||||
final["answer"] = ""
|
||||
yield final
|
||||
else:
|
||||
answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
|
||||
user_content = msg[-1].get("content", "[content not available]")
|
||||
@ -733,6 +727,84 @@ def tts(tts_mdl, text):
|
||||
return None
|
||||
return binascii.hexlify(bin).decode("utf-8")
|
||||
|
||||
|
||||
class _ThinkStreamState:
|
||||
def __init__(self) -> None:
|
||||
self.full_text = ""
|
||||
self.last_idx = 0
|
||||
self.endswith_think = False
|
||||
self.last_full = ""
|
||||
self.last_model_full = ""
|
||||
self.in_think = False
|
||||
self.buffer = ""
|
||||
|
||||
|
||||
def _next_think_delta(state: _ThinkStreamState) -> str:
|
||||
full_text = state.full_text
|
||||
if full_text == state.last_full:
|
||||
return ""
|
||||
state.last_full = full_text
|
||||
delta_ans = full_text[state.last_idx:]
|
||||
|
||||
if delta_ans.find("<think>") == 0:
|
||||
state.last_idx += len("<think>")
|
||||
return "<think>"
|
||||
if delta_ans.find("<think>") > 0:
|
||||
delta_text = full_text[state.last_idx:state.last_idx + delta_ans.find("<think>")]
|
||||
state.last_idx += delta_ans.find("<think>")
|
||||
return delta_text
|
||||
if delta_ans.endswith("</think>"):
|
||||
state.endswith_think = True
|
||||
elif state.endswith_think:
|
||||
state.endswith_think = False
|
||||
return "</think>"
|
||||
|
||||
state.last_idx = len(full_text)
|
||||
if full_text.endswith("</think>"):
|
||||
state.last_idx -= len("</think>")
|
||||
return re.sub(r"(<think>|</think>)", "", delta_ans)
|
||||
|
||||
|
||||
async def _stream_with_think_delta(stream_iter, min_tokens: int = 16):
|
||||
state = _ThinkStreamState()
|
||||
async for chunk in stream_iter:
|
||||
if not chunk:
|
||||
continue
|
||||
if chunk.startswith(state.last_model_full):
|
||||
new_part = chunk[len(state.last_model_full):]
|
||||
state.last_model_full = chunk
|
||||
else:
|
||||
new_part = chunk
|
||||
state.last_model_full += chunk
|
||||
if not new_part:
|
||||
continue
|
||||
state.full_text += new_part
|
||||
delta = _next_think_delta(state)
|
||||
if not delta:
|
||||
continue
|
||||
if delta in ("<think>", "</think>"):
|
||||
if delta == "<think>" and state.in_think:
|
||||
continue
|
||||
if delta == "</think>" and not state.in_think:
|
||||
continue
|
||||
if state.buffer:
|
||||
yield ("text", state.buffer, state)
|
||||
state.buffer = ""
|
||||
state.in_think = delta == "<think>"
|
||||
yield ("marker", delta, state)
|
||||
continue
|
||||
state.buffer += delta
|
||||
if num_tokens_from_string(state.buffer) < min_tokens:
|
||||
continue
|
||||
yield ("text", state.buffer, state)
|
||||
state.buffer = ""
|
||||
|
||||
if state.buffer:
|
||||
yield ("text", state.buffer, state)
|
||||
state.buffer = ""
|
||||
if state.endswith_think:
|
||||
yield ("marker", "</think>", state)
|
||||
|
||||
async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
||||
doc_ids = search_config.get("doc_ids", [])
|
||||
rerank_mdl = None
|
||||
@ -798,11 +870,20 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
|
||||
refs["chunks"] = chunks_format(refs)
|
||||
return {"answer": answer, "reference": refs}
|
||||
|
||||
answer = ""
|
||||
async for ans in chat_mdl.async_chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
|
||||
answer = ans
|
||||
yield {"answer": answer, "reference": {}}
|
||||
yield decorate_answer(answer)
|
||||
stream_iter = chat_mdl.async_chat_streamly_delta(sys_prompt, msg, {"temperature": 0.1})
|
||||
last_state = None
|
||||
async for kind, value, state in _stream_with_think_delta(stream_iter):
|
||||
last_state = state
|
||||
if kind == "marker":
|
||||
flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
|
||||
yield {"answer": "", "reference": {}, "final": False, **flags}
|
||||
continue
|
||||
yield {"answer": value, "reference": {}, "final": False}
|
||||
full_answer = last_state.full_text if last_state else ""
|
||||
final = decorate_answer(full_answer)
|
||||
final["final"] = True
|
||||
final["answer"] = ""
|
||||
yield final
|
||||
|
||||
|
||||
async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
||||
|
||||
@ -441,3 +441,46 @@ class LLMBundle(LLM4Tenant):
|
||||
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
|
||||
generation.end()
|
||||
return
|
||||
|
||||
async def async_chat_streamly_delta(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
||||
total_tokens = 0
|
||||
ans = ""
|
||||
if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"):
|
||||
stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
|
||||
elif hasattr(self.mdl, "async_chat_streamly"):
|
||||
stream_fn = getattr(self.mdl, "async_chat_streamly", None)
|
||||
else:
|
||||
raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
|
||||
|
||||
generation = None
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
|
||||
|
||||
if stream_fn:
|
||||
chat_partial = partial(stream_fn, system, history, gen_conf)
|
||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||
try:
|
||||
async for txt in chat_partial(**use_kwargs):
|
||||
if isinstance(txt, int):
|
||||
total_tokens = txt
|
||||
break
|
||||
|
||||
if txt.endswith("</think>"):
|
||||
ans = ans[: -len("</think>")]
|
||||
|
||||
if not self.verbose_tool_use:
|
||||
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
||||
|
||||
ans += txt
|
||||
yield txt
|
||||
except Exception as e:
|
||||
if generation:
|
||||
generation.update(output={"error": str(e)})
|
||||
generation.end()
|
||||
raise
|
||||
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
||||
if generation:
|
||||
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
|
||||
generation.end()
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user