refactor: remove union types

Signed-off-by: Stream <Stream_2@qq.com>
This commit is contained in:
Stream
2026-01-31 00:39:57 +08:00
parent a87560d667
commit 9ad49340bf
13 changed files with 257 additions and 361 deletions

View File

@ -12,13 +12,13 @@ from typing import TYPE_CHECKING, Any, Literal, cast
from sqlalchemy import select
from core.agent.entities import AgentEntity, AgentLog, AgentOutputKind, AgentResult, AgentToolEntity, ExecutionContext
from core.agent.output_tools import build_agent_output_tools, select_output_tool_names
from core.agent.entities import AgentEntity, AgentLog, AgentResult, AgentToolEntity, ExecutionContext
from core.agent.output_tools import build_agent_output_tools
from core.agent.patterns import StrategyFactory
from core.app.entities.app_asset_entities import AppAssetFileTree
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app_assets.constants import AppAssetsAttrs
from core.file import File, FileTransferMethod, FileType, file_manager
from core.file import FileTransferMethod, FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output
@ -188,12 +188,11 @@ class LLMNode(Node[LLMNodeData]):
def _run(self) -> Generator:
node_inputs: dict[str, Any] = {}
process_data: dict[str, Any] = {}
clean_text = ""
usage = LLMUsage.empty_usage()
finish_reason = None
reasoning_content = "" # Initialize as empty string for consistency
usage: LLMUsage = LLMUsage.empty_usage()
finish_reason: str | None = None
reasoning_content: str = "" # Initialize as empty string for consistency
clean_text = "" # Initialize clean_text to avoid UnboundLocalError
variable_pool = self.graph_runtime_state.variable_pool
variable_pool: VariablePool = self.graph_runtime_state.variable_pool
try:
# Parse prompt template to separate static messages and context references
@ -253,8 +252,9 @@ class LLMNode(Node[LLMNodeData]):
)
query: str | None = None
if self.node_data.memory:
query = self.node_data.memory.query_prompt_template
memory_config = self.node_data.memory
if memory_config:
query = memory_config.query_prompt_template
if not query and (
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
):
@ -296,9 +296,16 @@ class LLMNode(Node[LLMNodeData]):
sandbox=self.graph_runtime_state.sandbox,
)
# Variables for outputs
generation_data: LLMGenerationData | None = None
structured_output: LLMStructuredOutput | None = None
structured_output_schema: Mapping[str, Any] | None
if self.node_data.structured_output_enabled:
if not self.node_data.structured_output:
raise ValueError("structured_output_enabled is True but structured_output is not set")
structured_output_schema = LLMNode.fetch_structured_output_schema(
structured_output=self.node_data.structured_output
)
else:
structured_output_schema = None
if self.node_data.computer_use:
sandbox = self.graph_runtime_state.sandbox
@ -312,6 +319,7 @@ class LLMNode(Node[LLMNodeData]):
stop=stop,
variable_pool=variable_pool,
tool_dependencies=tool_dependencies,
structured_output_schema=structured_output_schema
)
elif self.tool_call_enabled:
generator = self._invoke_llm_with_tools(
@ -322,6 +330,7 @@ class LLMNode(Node[LLMNodeData]):
variable_pool=variable_pool,
node_inputs=node_inputs,
process_data=process_data,
structured_output_schema=structured_output_schema
)
else:
# Use traditional LLM invocation
@ -331,8 +340,7 @@ class LLMNode(Node[LLMNodeData]):
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=self._node_data.structured_output_enabled,
structured_output=self._node_data.structured_output,
structured_output_schema=structured_output_schema,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self._node_id,
@ -498,8 +506,7 @@ class LLMNode(Node[LLMNodeData]):
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None = None,
user_id: str,
structured_output_enabled: bool,
structured_output: Mapping[str, Any] | None = None,
structured_output_schema: Mapping[str, Any] | None,
file_saver: LLMFileSaver,
file_outputs: list[File],
node_id: str,
@ -513,10 +520,7 @@ class LLMNode(Node[LLMNodeData]):
if not model_schema:
raise ValueError(f"Model schema not found for {node_data_model.name}")
if structured_output_enabled:
output_schema = LLMNode.fetch_structured_output_schema(
structured_output=structured_output or {},
)
if structured_output_schema:
request_start_time = time.perf_counter()
invoke_result = invoke_llm_with_structured_output(
@ -524,7 +528,7 @@ class LLMNode(Node[LLMNodeData]):
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=output_schema,
json_schema=structured_output_schema,
model_parameters=node_data_model.completion_params,
stop=list(stop or []),
user=user_id,
@ -1876,6 +1880,7 @@ class LLMNode(Node[LLMNodeData]):
variable_pool: VariablePool,
node_inputs: dict[str, Any],
process_data: dict[str, Any],
structured_output_schema: Mapping[str, Any] | None
) -> Generator[NodeEventBase, None, LLMGenerationData]:
"""Invoke LLM with tools support (from Agent V2).
@ -1892,20 +1897,23 @@ class LLMNode(Node[LLMNodeData]):
# Use factory to create appropriate strategy
strategy = StrategyFactory.create_strategy(
tenant_id=self.tenant_id,
invoke_from=self.invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
model_features=model_features,
model_instance=model_instance,
tools=tool_instances,
files=prompt_files,
max_iterations=self._node_data.max_iterations or 10,
context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
structured_output_schema=structured_output_schema
)
# Run strategy
outputs = strategy.run(
prompt_messages=list(prompt_messages),
model_parameters=self._node_data.model.completion_params,
stop=list(stop or []),
stream=False,
stop=list(stop or [])
)
result = yield from self._process_tool_outputs(outputs)
@ -1919,6 +1927,7 @@ class LLMNode(Node[LLMNodeData]):
stop: Sequence[str] | None,
variable_pool: VariablePool,
tool_dependencies: ToolDependencies | None,
structured_output_schema: Mapping[str, Any] | None
) -> Generator[NodeEventBase, None, LLMGenerationData]:
result: LLMGenerationData | None = None
sandbox_output_files: list[File] = []
@ -1927,37 +1936,25 @@ class LLMNode(Node[LLMNodeData]):
with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_dependencies) as session:
prompt_files = self._extract_prompt_files(variable_pool)
model_features = self._get_model_features(model_instance)
structured_output_schema = None
if self._node_data.structured_output_enabled:
structured_output_schema = LLMNode.fetch_structured_output_schema(
structured_output=self._node_data.structured_output or {},
)
output_tools = build_agent_output_tools(
strategy = StrategyFactory.create_strategy(
tenant_id=self.tenant_id,
invoke_from=self.invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
output_tool_names=select_output_tool_names(
structured_output_enabled=self._node_data.structured_output_enabled,
include_illegal_output=True,
),
structured_output_schema=structured_output_schema,
)
strategy = StrategyFactory.create_strategy(
model_features=model_features,
model_instance=model_instance,
tools=[session.bash_tool, *output_tools],
tools=[session.bash_tool],
files=prompt_files,
max_iterations=self._node_data.max_iterations or 100,
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
structured_output_schema=structured_output_schema
)
outputs = strategy.run(
prompt_messages=list(prompt_messages),
model_parameters=self._node_data.model.completion_params,
stop=list(stop or []),
stream=False,
stop=list(stop or [])
)
result = yield from self._process_tool_outputs(outputs)
@ -2055,16 +2052,9 @@ class LLMNode(Node[LLMNodeData]):
structured_output=self._node_data.structured_output or {},
)
tool_instances.extend(
build_agent_output_tools(
tenant_id=self.tenant_id,
invoke_from=self.invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
output_tool_names=select_output_tool_names(
structured_output_enabled=self._node_data.structured_output_enabled,
include_illegal_output=True,
),
structured_output_schema=structured_output_schema,
)
build_agent_output_tools(tenant_id=self.tenant_id, invoke_from=self.invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
structured_output_schema=structured_output_schema)
)
return tool_instances
@ -2485,6 +2475,7 @@ class LLMNode(Node[LLMNodeData]):
content_position = 0
tool_call_seen_index: dict[str, int] = {}
for trace_segment in trace_state.trace_segments:
# FIXME: These if will never happen
if trace_segment.type == "thought":
sequence.append({"type": "reasoning", "index": reasoning_index})
reasoning_index += 1
@ -2564,32 +2555,25 @@ class LLMNode(Node[LLMNodeData]):
if not isinstance(exception.value, AgentResult):
raise ValueError(f"Unexpected output type: {type(exception.value)}") from exception
state.agent.agent_result = exception.value
if not state.agent.agent_result:
agent_result = state.agent.agent_result
if not agent_result:
raise ValueError("No agent result found in tool outputs")
output_payload = state.agent.agent_result.output
structured_output_data: Mapping[str, Any] | None = None
if isinstance(output_payload, AgentResult.StructuredOutput):
output_kind = output_payload.output_kind
if output_kind == AgentOutputKind.ILLEGAL_OUTPUT:
raise ValueError("Agent returned illegal output")
if output_kind in {AgentOutputKind.FINAL_OUTPUT_ANSWER, AgentOutputKind.OUTPUT_TEXT}:
if not output_payload.output_text:
raise ValueError("Agent returned empty text output")
state.aggregate.text = output_payload.output_text
elif output_kind == AgentOutputKind.FINAL_STRUCTURED_OUTPUT:
if output_payload.output_data is None:
raise ValueError("Agent returned empty structured output")
else:
raise ValueError("Agent returned unsupported output kind")
if output_payload.output_data is not None:
if not isinstance(output_payload.output_data, Mapping):
raise ValueError("Agent returned invalid structured output")
structured_output_data = output_payload.output_data
output_payload = agent_result.output
if isinstance(output_payload, dict):
state.aggregate.structured_output = LLMStructuredOutput(
structured_output=convert_file_refs_in_output(
output=output_payload,
json_schema=LLMNode.fetch_structured_output_schema(
structured_output=self._node_data.structured_output or {},
),
tenant_id=self.tenant_id,
)
)
state.aggregate.text = json.dumps(output_payload)
elif isinstance(output_payload, str):
state.aggregate.text = output_payload
else:
if not output_payload:
raise ValueError("Agent returned empty output")
state.aggregate.text = str(output_payload)
raise ValueError(f"Unexpected output type: {type(output_payload)}")
state.aggregate.files = state.agent.agent_result.files
if state.agent.agent_result.usage:
@ -2597,17 +2581,6 @@ class LLMNode(Node[LLMNodeData]):
if state.agent.agent_result.finish_reason:
state.aggregate.finish_reason = state.agent.agent_result.finish_reason
if structured_output_data is not None:
output_schema = LLMNode.fetch_structured_output_schema(
structured_output=self._node_data.structured_output or {},
)
converted_output = convert_file_refs_in_output(
output=structured_output_data,
json_schema=output_schema,
tenant_id=self.tenant_id,
)
state.aggregate.structured_output = LLMStructuredOutput(structured_output=converted_output)
yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate)
yield from self._close_streams()