refactor: remove union types

Signed-off-by: Stream <Stream_2@qq.com>
This commit is contained in:
Stream
2026-01-31 00:39:57 +08:00
parent a87560d667
commit 9ad49340bf
13 changed files with 257 additions and 361 deletions

View File

@ -5,13 +5,10 @@ import uuid
from collections.abc import Generator
from typing import Any, Literal, Protocol, Union, cast
from core.agent.entities import AgentLog, AgentOutputKind, AgentResult
from core.agent.entities import AgentLog, AgentResult
from core.agent.output_tools import (
FINAL_OUTPUT_TOOL,
FINAL_STRUCTURED_OUTPUT_TOOL,
ILLEGAL_OUTPUT_TOOL,
OUTPUT_TEXT_TOOL,
OUTPUT_TOOL_NAME_SET,
TERMINAL_OUTPUT_MESSAGE,
)
from core.file import File
from core.model_runtime.entities import (
@ -36,21 +33,11 @@ class FunctionCallStrategy(AgentPattern):
self,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any],
stop: list[str] = [],
stream: bool = True,
stop: list[str]
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the function call agent strategy."""
# Convert tools to prompt format
prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format()
tool_instance_names = {tool.entity.identity.name for tool in self.tools}
available_output_tool_names = {tool.name for tool in prompt_tools if tool.name in OUTPUT_TOOL_NAME_SET}
if FINAL_STRUCTURED_OUTPUT_TOOL in available_output_tool_names:
terminal_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL
elif FINAL_OUTPUT_TOOL in available_output_tool_names:
terminal_tool_name = FINAL_OUTPUT_TOOL
else:
raise ValueError("No terminal output tool configured")
allow_illegal_output = ILLEGAL_OUTPUT_TOOL in tool_instance_names
# Initialize tracking
iteration_step: int = 1
@ -60,10 +47,10 @@ class FunctionCallStrategy(AgentPattern):
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
final_text: str = ""
structured_output_payload: dict[str, Any] | None = None
final_tool_args: dict[str, Any] = {"!!!": "!!!"}
output_text_payload: str | None = None
finish_reason: str | None = None
output_files: list[File] = [] # Track files produced by tools
terminal_output_seen = False
class _LLMInvoker(Protocol):
def invoke_llm(
@ -87,11 +74,7 @@ class FunctionCallStrategy(AgentPattern):
data={},
)
yield round_log
# On last iteration, restrict tools to output tools
if iteration_step == max_iterations:
current_tools = [tool for tool in prompt_tools if tool.name in available_output_tool_names]
else:
current_tools = prompt_tools
model_log = self._create_log(
label=f"{self.model_instance.model} Thought",
log_type=AgentLog.LogType.THOUGHT,
@ -112,7 +95,7 @@ class FunctionCallStrategy(AgentPattern):
chunks = invoker.invoke_llm(
prompt_messages=messages,
model_parameters=model_parameters,
tools=current_tools,
tools=prompt_tools,
stop=stop,
stream=False,
user=self.context.user_id,
@ -124,19 +107,15 @@ class FunctionCallStrategy(AgentPattern):
chunks, round_usage, model_log, emit_chunks=False
)
if not tool_calls:
if not allow_illegal_output:
raise ValueError("Model did not call any tools")
tool_calls = [
(
str(uuid.uuid4()),
ILLEGAL_OUTPUT_TOOL,
{
"raw": response_content,
},
)
]
response_content = ""
if response_content:
replaced_tool_call = (
str(uuid.uuid4()),
ILLEGAL_OUTPUT_TOOL,
{
"raw": response_content,
},
)
tool_calls.append(replaced_tool_call)
messages.append(self._create_assistant_message("", tool_calls))
@ -149,35 +128,23 @@ class FunctionCallStrategy(AgentPattern):
if chunk_finish_reason:
finish_reason = chunk_finish_reason
assert len(tool_calls) > 0
# Process tool calls
tool_outputs: dict[str, str] = {}
if tool_calls:
function_call_state = True
terminal_tool_seen = False
# Execute tools
for tool_call_id, tool_name, tool_args in tool_calls:
if tool_name == OUTPUT_TEXT_TOOL:
output_text_payload = self._format_output_text(tool_args.get("text"))
elif tool_name == FINAL_STRUCTURED_OUTPUT_TOOL:
data = tool_args.get("data")
structured_output_payload = cast(dict[str, Any] | None, data)
if tool_name == terminal_tool_name:
terminal_tool_seen = True
elif tool_name == FINAL_OUTPUT_TOOL:
final_text = self._format_output_text(tool_args.get("text"))
if tool_name == terminal_tool_name:
terminal_tool_seen = True
tool_response, tool_files, _ = yield from self._handle_tool_call(
tool_name, tool_args, tool_call_id, messages, round_log
)
tool_outputs[tool_name] = tool_response
# Track files produced by tools
output_files.extend(tool_files)
if terminal_tool_seen:
terminal_output_seen = True
function_call_state = True
# Execute tools
for tool_call_id, tool_name, tool_args in tool_calls:
tool_response, tool_files, _ = yield from self._handle_tool_call(
tool_name, tool_args, tool_call_id, messages, round_log
)
tool_outputs[tool_name] = tool_response
# Track files produced by tools
output_files.extend(tool_files)
if tool_response == TERMINAL_OUTPUT_MESSAGE:
function_call_state = False
final_tool_args = tool_args
yield self._finish_log(
round_log,
data={
@ -196,31 +163,16 @@ class FunctionCallStrategy(AgentPattern):
# Return final result
from core.agent.entities import AgentResult
output_payload: str | AgentResult.StructuredOutput
if terminal_tool_name == FINAL_STRUCTURED_OUTPUT_TOOL and terminal_output_seen:
output_payload = AgentResult.StructuredOutput(
output_kind=AgentOutputKind.FINAL_STRUCTURED_OUTPUT,
output_text=None,
output_data=structured_output_payload,
)
elif final_text:
output_payload = AgentResult.StructuredOutput(
output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER,
output_text=final_text,
output_data=structured_output_payload,
)
elif output_text_payload:
output_payload = AgentResult.StructuredOutput(
output_kind=AgentOutputKind.OUTPUT_TEXT,
output_text=str(output_text_payload),
output_data=None,
)
output_payload: str | dict
output_text = final_tool_args.get("text")
output_structured_payload = final_tool_args.get("data")
if isinstance(output_structured_payload, dict):
output_payload = output_structured_payload
elif isinstance(output_text, str):
output_payload = output_text
else:
output_payload = AgentResult.StructuredOutput(
output_kind=AgentOutputKind.ILLEGAL_OUTPUT,
output_text="Model failed to produce a final output.",
output_data=None,
)
raise ValueError("Final output is not a string or structured data.")
return AgentResult(
output=output_payload,