mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 01:18:05 +08:00
Merge fix/chore-fix into dev/plugin-deploy
This commit is contained in:
@ -33,7 +33,7 @@ _TEXT_COLOR_MAPPING = {
|
||||
|
||||
class WorkflowLoggingCallback(WorkflowCallback):
|
||||
def __init__(self) -> None:
|
||||
self.current_node_id = None
|
||||
self.current_node_id: Optional[str] = None
|
||||
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
|
||||
@ -37,12 +37,15 @@ class NodeRunResult(BaseModel):
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
|
||||
|
||||
inputs: Optional[Mapping[str, Any]] = None # node inputs
|
||||
process_data: Optional[dict[str, Any]] = None # process data
|
||||
process_data: Optional[Mapping[str, Any]] = None # process data
|
||||
outputs: Optional[Mapping[str, Any]] = None # node outputs
|
||||
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
|
||||
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # node metadata
|
||||
llm_usage: Optional[LLMUsage] = None # llm usage
|
||||
|
||||
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
|
||||
|
||||
error: Optional[str] = None # error message if status is failed
|
||||
error_type: Optional[str] = None # error type if status is failed
|
||||
|
||||
# single step node run retry
|
||||
retry_index: int = 0
|
||||
|
||||
@ -5,7 +5,7 @@ from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
|
||||
|
||||
class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
||||
def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
|
||||
def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState):
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
|
||||
@ -33,7 +33,7 @@ class GraphRunSucceededEvent(BaseGraphEvent):
|
||||
|
||||
class GraphRunFailedEvent(BaseGraphEvent):
|
||||
error: str = Field(..., description="failed reason")
|
||||
exceptions_count: Optional[int] = Field(description="exception count", default=0)
|
||||
exceptions_count: int = Field(description="exception count", default=0)
|
||||
|
||||
|
||||
class GraphRunPartialSucceededEvent(BaseGraphEvent):
|
||||
@ -97,6 +97,12 @@ class NodeInIterationFailedEvent(BaseNodeEvent):
|
||||
error: str = Field(..., description="error")
|
||||
|
||||
|
||||
class NodeRunRetryEvent(NodeRunStartedEvent):
|
||||
error: str = Field(..., description="error")
|
||||
retry_index: int = Field(..., description="which retry attempt is about to be performed")
|
||||
start_at: datetime = Field(..., description="retry start time")
|
||||
|
||||
|
||||
###########################################
|
||||
# Parallel Branch Events
|
||||
###########################################
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
|
||||
@ -170,7 +172,9 @@ class Graph(BaseModel):
|
||||
for parallel in parallel_mapping.values():
|
||||
if parallel.parent_parallel_id:
|
||||
cls._check_exceed_parallel_limit(
|
||||
parallel_mapping=parallel_mapping, level_limit=3, parent_parallel_id=parallel.parent_parallel_id
|
||||
parallel_mapping=parallel_mapping,
|
||||
level_limit=dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
|
||||
parent_parallel_id=parallel.parent_parallel_id,
|
||||
)
|
||||
|
||||
# init answer stream generate routes
|
||||
@ -307,26 +311,17 @@ class Graph(BaseModel):
|
||||
parallel = None
|
||||
if len(target_node_edges) > 1:
|
||||
# fetch all node ids in current parallels
|
||||
parallel_branch_node_ids = {}
|
||||
condition_edge_mappings = {}
|
||||
parallel_branch_node_ids = defaultdict(list)
|
||||
condition_edge_mappings = defaultdict(list)
|
||||
for graph_edge in target_node_edges:
|
||||
if graph_edge.run_condition is None:
|
||||
if "default" not in parallel_branch_node_ids:
|
||||
parallel_branch_node_ids["default"] = []
|
||||
|
||||
parallel_branch_node_ids["default"].append(graph_edge.target_node_id)
|
||||
else:
|
||||
condition_hash = graph_edge.run_condition.hash
|
||||
if condition_hash not in condition_edge_mappings:
|
||||
condition_edge_mappings[condition_hash] = []
|
||||
|
||||
condition_edge_mappings[condition_hash].append(graph_edge)
|
||||
|
||||
for condition_hash, graph_edges in condition_edge_mappings.items():
|
||||
if len(graph_edges) > 1:
|
||||
if condition_hash not in parallel_branch_node_ids:
|
||||
parallel_branch_node_ids[condition_hash] = []
|
||||
|
||||
for graph_edge in graph_edges:
|
||||
parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id)
|
||||
|
||||
@ -415,7 +410,7 @@ class Graph(BaseModel):
|
||||
if condition_edge_mappings:
|
||||
for condition_hash, graph_edges in condition_edge_mappings.items():
|
||||
for graph_edge in graph_edges:
|
||||
current_parallel: GraphParallel | None = cls._get_current_parallel(
|
||||
current_parallel = cls._get_current_parallel(
|
||||
parallel_mapping=parallel_mapping,
|
||||
graph_edge=graph_edge,
|
||||
parallel=condition_parallels.get(condition_hash),
|
||||
|
||||
@ -6,6 +6,7 @@ import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor, wait
|
||||
from copy import copy, deepcopy
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
@ -26,6 +27,7 @@ from core.workflow.graph_engine.entities.event import (
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
@ -39,6 +41,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntime
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
||||
@ -65,7 +68,7 @@ class GraphEngineThreadPool(ThreadPoolExecutor):
|
||||
self.max_submit_count = max_submit_count
|
||||
self.submit_count = 0
|
||||
|
||||
def submit(self, fn, *args, **kwargs):
|
||||
def submit(self, fn, /, *args, **kwargs):
|
||||
self.submit_count += 1
|
||||
self.check_is_full()
|
||||
|
||||
@ -139,7 +142,8 @@ class GraphEngine:
|
||||
def run(self) -> Generator[GraphEngineEvent, None, None]:
|
||||
# trigger graph run start event
|
||||
yield GraphRunStartedEvent()
|
||||
handle_exceptions = []
|
||||
handle_exceptions: list[str] = []
|
||||
stream_processor: StreamProcessor
|
||||
|
||||
try:
|
||||
if self.init_params.workflow_type == WorkflowType.CHAT:
|
||||
@ -349,7 +353,7 @@ class GraphEngine:
|
||||
|
||||
if any(edge.run_condition for edge in edge_mappings):
|
||||
# if nodes has run conditions, get node id which branch to take based on the run condition results
|
||||
condition_edge_mappings = {}
|
||||
condition_edge_mappings: dict[str, list[GraphEdge]] = {}
|
||||
for edge in edge_mappings:
|
||||
if edge.run_condition:
|
||||
run_condition_hash = edge.run_condition.hash
|
||||
@ -363,6 +367,9 @@ class GraphEngine:
|
||||
continue
|
||||
|
||||
edge = cast(GraphEdge, sub_edge_mappings[0])
|
||||
if edge.run_condition is None:
|
||||
logger.warning(f"Edge {edge.target_node_id} run condition is None")
|
||||
continue
|
||||
|
||||
result = ConditionManager.get_condition_handler(
|
||||
init_params=self.init_params,
|
||||
@ -386,11 +393,11 @@ class GraphEngine:
|
||||
handle_exceptions=handle_exceptions,
|
||||
)
|
||||
|
||||
for item in parallel_generator:
|
||||
if isinstance(item, str):
|
||||
final_node_id = item
|
||||
for parallel_result in parallel_generator:
|
||||
if isinstance(parallel_result, str):
|
||||
final_node_id = parallel_result
|
||||
else:
|
||||
yield item
|
||||
yield parallel_result
|
||||
|
||||
break
|
||||
|
||||
@ -412,11 +419,11 @@ class GraphEngine:
|
||||
handle_exceptions=handle_exceptions,
|
||||
)
|
||||
|
||||
for item in parallel_generator:
|
||||
if isinstance(item, str):
|
||||
final_node_id = item
|
||||
for generated_item in parallel_generator:
|
||||
if isinstance(generated_item, str):
|
||||
final_node_id = generated_item
|
||||
else:
|
||||
yield item
|
||||
yield generated_item
|
||||
|
||||
if not final_node_id:
|
||||
break
|
||||
@ -587,7 +594,7 @@ class GraphEngine:
|
||||
|
||||
def _run_node(
|
||||
self,
|
||||
node_instance: BaseNode,
|
||||
node_instance: BaseNode[BaseNodeData],
|
||||
route_node_state: RouteNodeState,
|
||||
parallel_id: Optional[str] = None,
|
||||
parallel_start_node_id: Optional[str] = None,
|
||||
@ -613,36 +620,120 @@ class GraphEngine:
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
max_retries = node_instance.node_data.retry_config.max_retries
|
||||
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
|
||||
retries = 0
|
||||
should_continue_retry = True
|
||||
while should_continue_retry and retries <= max_retries:
|
||||
try:
|
||||
# run node
|
||||
retry_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
generator = node_instance.run()
|
||||
for item in generator:
|
||||
if isinstance(item, GraphEngineEvent):
|
||||
if isinstance(item, BaseIterationEvent):
|
||||
# add parallel info to iteration event
|
||||
item.parallel_id = parallel_id
|
||||
item.parallel_start_node_id = parallel_start_node_id
|
||||
item.parent_parallel_id = parent_parallel_id
|
||||
item.parent_parallel_start_node_id = parent_parallel_start_node_id
|
||||
|
||||
try:
|
||||
# run node
|
||||
generator = node_instance.run()
|
||||
for item in generator:
|
||||
if isinstance(item, GraphEngineEvent):
|
||||
if isinstance(item, BaseIterationEvent):
|
||||
# add parallel info to iteration event
|
||||
item.parallel_id = parallel_id
|
||||
item.parallel_start_node_id = parallel_start_node_id
|
||||
item.parent_parallel_id = parent_parallel_id
|
||||
item.parent_parallel_start_node_id = parent_parallel_start_node_id
|
||||
yield item
|
||||
else:
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
run_result = item.run_result
|
||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
if (
|
||||
retries == max_retries
|
||||
and node_instance.node_type == NodeType.HTTP_REQUEST
|
||||
and run_result.outputs
|
||||
and not node_instance.should_continue_on_error
|
||||
):
|
||||
run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
if node_instance.should_retry and retries < max_retries:
|
||||
retries += 1
|
||||
route_node_state.node_run_result = run_result
|
||||
yield NodeRunRetryEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
predecessor_node_id=node_instance.previous_node_id,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
error=run_result.error or "Unknown error",
|
||||
retry_index=retries,
|
||||
start_at=retry_start_at,
|
||||
)
|
||||
time.sleep(retry_interval)
|
||||
continue
|
||||
route_node_state.set_finished(run_result=run_result)
|
||||
|
||||
yield item
|
||||
else:
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
run_result = item.run_result
|
||||
route_node_state.set_finished(run_result=run_result)
|
||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
if node_instance.should_continue_on_error:
|
||||
# if run failed, handle error
|
||||
run_result = self._handle_continue_on_error(
|
||||
node_instance,
|
||||
item.run_result,
|
||||
self.graph_runtime_state.variable_pool,
|
||||
handle_exceptions=handle_exceptions,
|
||||
)
|
||||
route_node_state.node_run_result = run_result
|
||||
route_node_state.status = RouteNodeState.Status.EXCEPTION
|
||||
if run_result.outputs:
|
||||
for variable_key, variable_value in run_result.outputs.items():
|
||||
# append variables to variable pool recursively
|
||||
self._append_variables_recursively(
|
||||
node_id=node_instance.node_id,
|
||||
variable_key_list=[variable_key],
|
||||
variable_value=variable_value,
|
||||
)
|
||||
yield NodeRunExceptionEvent(
|
||||
error=run_result.error or "System Error",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
should_continue_retry = False
|
||||
else:
|
||||
yield NodeRunFailedEvent(
|
||||
error=route_node_state.failed_reason or "Unknown error.",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
should_continue_retry = False
|
||||
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
|
||||
node_instance.node_id
|
||||
):
|
||||
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
|
||||
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||
# plus state total_tokens
|
||||
self.graph_runtime_state.total_tokens += int(
|
||||
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
if node_instance.should_continue_on_error:
|
||||
# if run failed, handle error
|
||||
run_result = self._handle_continue_on_error(
|
||||
node_instance,
|
||||
item.run_result,
|
||||
self.graph_runtime_state.variable_pool,
|
||||
handle_exceptions=handle_exceptions,
|
||||
)
|
||||
route_node_state.node_run_result = run_result
|
||||
route_node_state.status = RouteNodeState.Status.EXCEPTION
|
||||
if run_result.llm_usage:
|
||||
# use the latest usage
|
||||
self.graph_runtime_state.llm_usage += run_result.llm_usage
|
||||
|
||||
# append node output variables to variable pool
|
||||
if run_result.outputs:
|
||||
for variable_key, variable_value in run_result.outputs.items():
|
||||
# append variables to variable pool recursively
|
||||
@ -651,133 +742,86 @@ class GraphEngine:
|
||||
variable_key_list=[variable_key],
|
||||
variable_value=variable_value,
|
||||
)
|
||||
yield NodeRunExceptionEvent(
|
||||
error=run_result.error or "System Error",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
else:
|
||||
yield NodeRunFailedEvent(
|
||||
error=route_node_state.failed_reason or "Unknown error.",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
|
||||
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
|
||||
node_instance.node_id
|
||||
):
|
||||
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
|
||||
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||
# plus state total_tokens
|
||||
self.graph_runtime_state.total_tokens += int(
|
||||
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
if run_result.llm_usage:
|
||||
# use the latest usage
|
||||
self.graph_runtime_state.llm_usage += run_result.llm_usage
|
||||
|
||||
# append node output variables to variable pool
|
||||
if run_result.outputs:
|
||||
for variable_key, variable_value in run_result.outputs.items():
|
||||
# append variables to variable pool recursively
|
||||
self._append_variables_recursively(
|
||||
node_id=node_instance.node_id,
|
||||
variable_key_list=[variable_key],
|
||||
variable_value=variable_value,
|
||||
)
|
||||
|
||||
# add parallel info to run result metadata
|
||||
if parallel_id and parallel_start_node_id:
|
||||
# When setting metadata, convert to dict first
|
||||
if not run_result.metadata:
|
||||
run_result.metadata = {}
|
||||
|
||||
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
|
||||
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
|
||||
if parent_parallel_id and parent_parallel_start_node_id:
|
||||
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
|
||||
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
|
||||
parent_parallel_start_node_id
|
||||
)
|
||||
if parallel_id and parallel_start_node_id:
|
||||
metadata_dict = dict(run_result.metadata)
|
||||
metadata_dict[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
|
||||
metadata_dict[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
|
||||
if parent_parallel_id and parent_parallel_start_node_id:
|
||||
metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
|
||||
metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
|
||||
parent_parallel_start_node_id
|
||||
)
|
||||
run_result.metadata = metadata_dict
|
||||
|
||||
yield NodeRunSucceededEvent(
|
||||
yield NodeRunSucceededEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
should_continue_retry = False
|
||||
|
||||
break
|
||||
elif isinstance(item, RunStreamChunkEvent):
|
||||
yield NodeRunStreamChunkEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
chunk_content=item.chunk_content,
|
||||
from_variable_selector=item.from_variable_selector,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
|
||||
break
|
||||
elif isinstance(item, RunStreamChunkEvent):
|
||||
yield NodeRunStreamChunkEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
chunk_content=item.chunk_content,
|
||||
from_variable_selector=item.from_variable_selector,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
elif isinstance(item, RunRetrieverResourceEvent):
|
||||
yield NodeRunRetrieverResourceEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
retriever_resources=item.retriever_resources,
|
||||
context=item.context,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
except GenerateTaskStoppedError:
|
||||
# trigger node run failed event
|
||||
route_node_state.status = RouteNodeState.Status.FAILED
|
||||
route_node_state.failed_reason = "Workflow stopped."
|
||||
yield NodeRunFailedEvent(
|
||||
error="Workflow stopped.",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception(f"Node {node_instance.node_data.title} run failed")
|
||||
raise e
|
||||
finally:
|
||||
db.session.close()
|
||||
elif isinstance(item, RunRetrieverResourceEvent):
|
||||
yield NodeRunRetrieverResourceEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
retriever_resources=item.retriever_resources,
|
||||
context=item.context,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
except GenerateTaskStoppedError:
|
||||
# trigger node run failed event
|
||||
route_node_state.status = RouteNodeState.Status.FAILED
|
||||
route_node_state.failed_reason = "Workflow stopped."
|
||||
yield NodeRunFailedEvent(
|
||||
error="Workflow stopped.",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception(f"Node {node_instance.node_data.title} run failed")
|
||||
raise e
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
|
||||
"""
|
||||
@ -836,8 +880,8 @@ class GraphEngine:
|
||||
variable_pool.add([node_instance.node_id, "error_message"], error_result.error)
|
||||
variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type)
|
||||
# add error message to handle_exceptions
|
||||
handle_exceptions.append(error_result.error)
|
||||
node_error_args = {
|
||||
handle_exceptions.append(error_result.error or "")
|
||||
node_error_args: dict[str, Any] = {
|
||||
"status": WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
"error": error_result.error,
|
||||
"inputs": error_result.inputs,
|
||||
|
||||
@ -147,6 +147,8 @@ class AnswerStreamGeneratorRouter:
|
||||
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
|
||||
for edge in reverse_edges:
|
||||
source_node_id = edge.source_node_id
|
||||
if source_node_id not in node_id_config_mapping:
|
||||
continue
|
||||
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
|
||||
source_node_data = node_id_config_mapping[source_node_id].get("data", {})
|
||||
if (
|
||||
|
||||
@ -60,11 +60,10 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
|
||||
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
|
||||
|
||||
# remove unreachable nodes
|
||||
self._remove_unreachable_nodes(event)
|
||||
|
||||
# generate stream outputs
|
||||
yield from self._generate_stream_outputs_when_node_finished(event)
|
||||
yield from self._generate_stream_outputs_when_node_finished(cast(NodeRunSucceededEvent, event))
|
||||
else:
|
||||
yield event
|
||||
|
||||
@ -131,7 +130,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
chunk_content=text,
|
||||
from_variable_selector=value_selector,
|
||||
from_variable_selector=list(value_selector),
|
||||
route_node_state=event.route_node_state,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
|
||||
@ -1,10 +1,13 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent
|
||||
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamProcessor(ABC):
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||
@ -16,7 +19,7 @@ class StreamProcessor(ABC):
|
||||
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
|
||||
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None:
|
||||
finished_node_id = event.route_node_state.node_id
|
||||
if finished_node_id not in self.rest_node_ids:
|
||||
return
|
||||
@ -29,15 +32,24 @@ class StreamProcessor(ABC):
|
||||
return
|
||||
|
||||
if run_result.edge_source_handle:
|
||||
reachable_node_ids = []
|
||||
unreachable_first_node_ids = []
|
||||
reachable_node_ids: list[str] = []
|
||||
unreachable_first_node_ids: list[str] = []
|
||||
if finished_node_id not in self.graph.edge_mapping:
|
||||
logger.warning(f"node {finished_node_id} has no edge mapping")
|
||||
return
|
||||
for edge in self.graph.edge_mapping[finished_node_id]:
|
||||
if (
|
||||
edge.run_condition
|
||||
and edge.run_condition.branch_identify
|
||||
and run_result.edge_source_handle == edge.run_condition.branch_identify
|
||||
):
|
||||
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
|
||||
# remove unreachable nodes
|
||||
# FIXME: because of the code branch can combine directly, so for answer node
|
||||
# we remove the node maybe shortcut the answer node, so comment this code for now
|
||||
# there is not effect on the answer node and the workflow, when we have a better solution
|
||||
# we can open this code. Issues: #11542 #9560 #10638 #10564
|
||||
|
||||
# reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
|
||||
continue
|
||||
else:
|
||||
unreachable_first_node_ids.append(edge.target_node_id)
|
||||
|
||||
@ -38,7 +38,8 @@ class DefaultValue(BaseModel):
|
||||
@staticmethod
|
||||
def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
|
||||
"""Unified array type validation"""
|
||||
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
|
||||
# FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it
|
||||
return isinstance(value, list) and all(isinstance(x, element_type) for x in value) # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def _convert_number(value: str) -> float:
|
||||
@ -84,7 +85,7 @@ class DefaultValue(BaseModel):
|
||||
},
|
||||
}
|
||||
|
||||
validator = type_validators.get(self.type)
|
||||
validator: dict[str, Any] = type_validators.get(self.type, {})
|
||||
if not validator:
|
||||
if self.type == DefaultValueType.ARRAY_FILES:
|
||||
# Handle files type
|
||||
@ -106,12 +107,25 @@ class DefaultValue(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
class RetryConfig(BaseModel):
|
||||
"""node retry config"""
|
||||
|
||||
max_retries: int = 0 # max retry times
|
||||
retry_interval: int = 0 # retry interval in milliseconds
|
||||
retry_enabled: bool = False # whether retry is enabled
|
||||
|
||||
@property
|
||||
def retry_interval_seconds(self) -> float:
|
||||
return self.retry_interval / 1000
|
||||
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: Optional[str] = None
|
||||
error_strategy: Optional[ErrorStrategy] = None
|
||||
default_value: Optional[list[DefaultValue]] = None
|
||||
version: str = "1"
|
||||
retry_config: RetryConfig = RetryConfig()
|
||||
|
||||
@property
|
||||
def default_value_dict(self):
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
class BaseNodeError(Exception):
|
||||
class BaseNodeError(ValueError):
|
||||
"""Base class for node errors."""
|
||||
|
||||
pass
|
||||
|
||||
@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
|
||||
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
@ -72,7 +72,11 @@ class BaseNode(Generic[GenericNodeData]):
|
||||
result = self._run()
|
||||
except Exception as e:
|
||||
logger.exception(f"Node {self.node_id} failed to run")
|
||||
result = NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=str(e), error_type="SystemError")
|
||||
result = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
error_type="WorkflowNodeError",
|
||||
)
|
||||
|
||||
if isinstance(result, NodeRunResult):
|
||||
yield RunCompletedEvent(run_result=result)
|
||||
@ -143,3 +147,12 @@ class BaseNode(Generic[GenericNodeData]):
|
||||
bool: if should continue on error
|
||||
"""
|
||||
return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
|
||||
|
||||
@property
|
||||
def should_retry(self) -> bool:
|
||||
"""judge if should retry
|
||||
|
||||
Returns:
|
||||
bool: if should retry
|
||||
"""
|
||||
return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
@ -59,7 +59,7 @@ class CodeNode(BaseNode[CodeNodeData]):
|
||||
)
|
||||
|
||||
# Transform result
|
||||
result = self._transform_result(result, self.node_data.outputs)
|
||||
result = self._transform_result(result=result, output_schema=self.node_data.outputs)
|
||||
except (CodeExecutionError, CodeNodeError) as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
|
||||
@ -67,18 +67,17 @@ class CodeNode(BaseNode[CodeNodeData]):
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
|
||||
|
||||
def _check_string(self, value: str, variable: str) -> str:
|
||||
def _check_string(self, value: str | None, variable: str) -> str | None:
|
||||
"""
|
||||
Check string
|
||||
:param value: value
|
||||
:param variable: variable
|
||||
:return:
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, str):
|
||||
if value is None:
|
||||
return None
|
||||
else:
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a string")
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a string")
|
||||
|
||||
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
|
||||
raise OutputValidationError(
|
||||
@ -88,18 +87,17 @@ class CodeNode(BaseNode[CodeNodeData]):
|
||||
|
||||
return value.replace("\x00", "")
|
||||
|
||||
def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]:
|
||||
def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
|
||||
"""
|
||||
Check number
|
||||
:param value: value
|
||||
:param variable: variable
|
||||
:return:
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, int | float):
|
||||
if value is None:
|
||||
return None
|
||||
else:
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a number")
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a number")
|
||||
|
||||
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
|
||||
raise OutputValidationError(
|
||||
@ -118,18 +116,16 @@ class CodeNode(BaseNode[CodeNodeData]):
|
||||
return value
|
||||
|
||||
def _transform_result(
|
||||
self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1
|
||||
) -> dict:
|
||||
"""
|
||||
Transform result
|
||||
:param result: result
|
||||
:param output_schema: output schema
|
||||
:return:
|
||||
"""
|
||||
self,
|
||||
result: Mapping[str, Any],
|
||||
output_schema: Optional[dict[str, CodeNodeData.Output]],
|
||||
prefix: str = "",
|
||||
depth: int = 1,
|
||||
):
|
||||
if depth > dify_config.CODE_MAX_DEPTH:
|
||||
raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
|
||||
|
||||
transformed_result = {}
|
||||
transformed_result: dict[str, Any] = {}
|
||||
if output_schema is None:
|
||||
# validate output thought instance type
|
||||
for output_name, output_value in result.items():
|
||||
|
||||
@ -14,7 +14,7 @@ class CodeNodeData(BaseNodeData):
|
||||
|
||||
class Output(BaseModel):
|
||||
type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
|
||||
children: Optional[dict[str, "Output"]] = None
|
||||
children: Optional[dict[str, "CodeNodeData.Output"]] = None
|
||||
|
||||
class Dependency(BaseModel):
|
||||
name: str
|
||||
|
||||
@ -1,19 +1,15 @@
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from typing import cast
|
||||
|
||||
import docx
|
||||
import pandas as pd
|
||||
import pypdfium2 # type: ignore
|
||||
import yaml # type: ignore
|
||||
from unstructured.partition.api import partition_via_api
|
||||
from unstructured.partition.email import partition_email
|
||||
from unstructured.partition.epub import partition_epub
|
||||
from unstructured.partition.msg import partition_msg
|
||||
from unstructured.partition.ppt import partition_ppt
|
||||
from unstructured.partition.pptx import partition_pptx
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod, file_manager
|
||||
@ -28,6 +24,8 @@ from models.workflow import WorkflowNodeExecutionStatus
|
||||
from .entities import DocumentExtractorNodeData
|
||||
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
|
||||
"""
|
||||
@ -162,7 +160,7 @@ def _extract_text_from_yaml(file_content: bytes) -> str:
|
||||
"""Extract the content from yaml file"""
|
||||
try:
|
||||
yaml_data = yaml.safe_load_all(file_content.decode("utf-8", "ignore"))
|
||||
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
|
||||
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
|
||||
except (UnicodeDecodeError, yaml.YAMLError) as e:
|
||||
raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e
|
||||
|
||||
@ -183,10 +181,43 @@ def _extract_text_from_pdf(file_content: bytes) -> str:
|
||||
|
||||
|
||||
def _extract_text_from_doc(file_content: bytes) -> str:
|
||||
"""
|
||||
Extract text from a DOC/DOCX file.
|
||||
For now support only paragraph and table add more if needed
|
||||
"""
|
||||
try:
|
||||
doc_file = io.BytesIO(file_content)
|
||||
doc = docx.Document(doc_file)
|
||||
return "\n".join([paragraph.text for paragraph in doc.paragraphs])
|
||||
text = []
|
||||
# Process paragraphs
|
||||
for paragraph in doc.paragraphs:
|
||||
if paragraph.text.strip():
|
||||
text.append(paragraph.text)
|
||||
|
||||
# Process tables
|
||||
for table in doc.tables:
|
||||
# Table header
|
||||
try:
|
||||
# table maybe cause errors so ignore it.
|
||||
if len(table.rows) > 0 and table.rows[0].cells is not None:
|
||||
# Check if any cell in the table has text
|
||||
has_content = False
|
||||
for row in table.rows:
|
||||
if any(cell.text.strip() for cell in row.cells):
|
||||
has_content = True
|
||||
break
|
||||
|
||||
if has_content:
|
||||
markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n"
|
||||
markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n"
|
||||
for row in table.rows[1:]:
|
||||
markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n"
|
||||
text.append(markdown_table)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
|
||||
continue
|
||||
|
||||
return "\n".join(text)
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
|
||||
|
||||
@ -199,9 +230,9 @@ def _download_file_content(file: File) -> bytes:
|
||||
raise FileDownloadError("Missing URL for remote file")
|
||||
response = ssrf_proxy.get(file.remote_url)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
return cast(bytes, response.content)
|
||||
else:
|
||||
return file_manager.download(file)
|
||||
return cast(bytes, file_manager.download(file))
|
||||
except Exception as e:
|
||||
raise FileDownloadError(f"Error downloading file: {str(e)}") from e
|
||||
|
||||
@ -256,6 +287,8 @@ def _extract_text_from_excel(file_content: bytes) -> str:
|
||||
|
||||
|
||||
def _extract_text_from_ppt(file_content: bytes) -> str:
|
||||
from unstructured.partition.ppt import partition_ppt
|
||||
|
||||
try:
|
||||
with io.BytesIO(file_content) as file:
|
||||
elements = partition_ppt(file=file)
|
||||
@ -265,6 +298,9 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
|
||||
|
||||
|
||||
def _extract_text_from_pptx(file_content: bytes) -> str:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
from unstructured.partition.pptx import partition_pptx
|
||||
|
||||
try:
|
||||
if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
|
||||
with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
|
||||
@ -287,6 +323,8 @@ def _extract_text_from_pptx(file_content: bytes) -> str:
|
||||
|
||||
|
||||
def _extract_text_from_epub(file_content: bytes) -> str:
|
||||
from unstructured.partition.epub import partition_epub
|
||||
|
||||
try:
|
||||
with io.BytesIO(file_content) as file:
|
||||
elements = partition_epub(file=file)
|
||||
@ -296,6 +334,8 @@ def _extract_text_from_epub(file_content: bytes) -> str:
|
||||
|
||||
|
||||
def _extract_text_from_eml(file_content: bytes) -> str:
|
||||
from unstructured.partition.email import partition_email
|
||||
|
||||
try:
|
||||
with io.BytesIO(file_content) as file:
|
||||
elements = partition_email(file=file)
|
||||
@ -305,6 +345,8 @@ def _extract_text_from_eml(file_content: bytes) -> str:
|
||||
|
||||
|
||||
def _extract_text_from_msg(file_content: bytes) -> str:
|
||||
from unstructured.partition.msg import partition_msg
|
||||
|
||||
try:
|
||||
with io.BytesIO(file_content) as file:
|
||||
elements = partition_msg(file=file)
|
||||
|
||||
@ -67,7 +67,7 @@ class EndStreamGeneratorRouter:
|
||||
and node_type == NodeType.LLM.value
|
||||
and variable_selector.value_selector[1] == "text"
|
||||
):
|
||||
value_selectors.append(variable_selector.value_selector)
|
||||
value_selectors.append(list(variable_selector.value_selector))
|
||||
|
||||
return value_selectors
|
||||
|
||||
@ -119,8 +119,7 @@ class EndStreamGeneratorRouter:
|
||||
current_node_id: str,
|
||||
end_node_id: str,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]],
|
||||
# type: ignore[name-defined]
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
end_dependencies: dict[str, list[str]],
|
||||
) -> None:
|
||||
"""
|
||||
@ -135,6 +134,8 @@ class EndStreamGeneratorRouter:
|
||||
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
|
||||
for edge in reverse_edges:
|
||||
source_node_id = edge.source_node_id
|
||||
if source_node_id not in node_id_config_mapping:
|
||||
continue
|
||||
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
|
||||
if source_node_type in {
|
||||
NodeType.IF_ELSE.value,
|
||||
|
||||
@ -23,7 +23,7 @@ class EndStreamProcessor(StreamProcessor):
|
||||
self.route_position[end_node_id] = 0
|
||||
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
|
||||
self.has_output = False
|
||||
self.output_node_ids = set()
|
||||
self.output_node_ids: set[str] = set()
|
||||
|
||||
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
|
||||
for event in generator:
|
||||
|
||||
@ -36,3 +36,4 @@ class FailBranchSourceHandle(StrEnum):
|
||||
|
||||
|
||||
CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
|
||||
RETRY_ON_ERROR_NODE_TYPE = CONTINUE_ON_ERROR_NODE_TYPE
|
||||
|
||||
@ -1,4 +1,10 @@
|
||||
from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from .event import (
|
||||
ModelInvokeCompletedEvent,
|
||||
RunCompletedEvent,
|
||||
RunRetrieverResourceEvent,
|
||||
RunRetryEvent,
|
||||
RunStreamChunkEvent,
|
||||
)
|
||||
from .types import NodeEvent
|
||||
|
||||
__all__ = [
|
||||
@ -6,5 +12,6 @@ __all__ = [
|
||||
"NodeEvent",
|
||||
"RunCompletedEvent",
|
||||
"RunRetrieverResourceEvent",
|
||||
"RunRetryEvent",
|
||||
"RunStreamChunkEvent",
|
||||
]
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class RunCompletedEvent(BaseModel):
|
||||
@ -26,3 +29,19 @@ class ModelInvokeCompletedEvent(BaseModel):
|
||||
text: str
|
||||
usage: LLMUsage
|
||||
finish_reason: str | None = None
|
||||
|
||||
|
||||
class RunRetryEvent(BaseModel):
|
||||
"""Node Run Retry event"""
|
||||
|
||||
error: str = Field(..., description="error")
|
||||
retry_index: int = Field(..., description="Retry attempt number")
|
||||
start_at: datetime = Field(..., description="Retry start time")
|
||||
|
||||
|
||||
class SingleStepRetryEvent(NodeRunResult):
|
||||
"""Single step retry event"""
|
||||
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RETRY
|
||||
|
||||
elapsed_time: float = Field(..., description="elapsed time")
|
||||
|
||||
@ -16,3 +16,7 @@ class InvalidHttpMethodError(HttpRequestNodeError):
|
||||
|
||||
class ResponseSizeError(HttpRequestNodeError):
|
||||
"""Raised when the response size exceeds the allowed threshold."""
|
||||
|
||||
|
||||
class RequestBodyError(HttpRequestNodeError):
|
||||
"""Raised when the request body is invalid."""
|
||||
|
||||
@ -23,6 +23,7 @@ from .exc import (
|
||||
FileFetchError,
|
||||
HttpRequestNodeError,
|
||||
InvalidHttpMethodError,
|
||||
RequestBodyError,
|
||||
ResponseSizeError,
|
||||
)
|
||||
|
||||
@ -45,6 +46,7 @@ class Executor:
|
||||
headers: dict[str, str]
|
||||
auth: HttpRequestNodeAuthorization
|
||||
timeout: HttpRequestNodeTimeout
|
||||
max_retries: int
|
||||
|
||||
boundary: str
|
||||
|
||||
@ -54,6 +56,7 @@ class Executor:
|
||||
node_data: HttpRequestNodeData,
|
||||
timeout: HttpRequestNodeTimeout,
|
||||
variable_pool: VariablePool,
|
||||
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||
):
|
||||
# If authorization API key is present, convert the API key using the variable pool
|
||||
if node_data.authorization.type == "api-key":
|
||||
@ -73,6 +76,7 @@ class Executor:
|
||||
self.files = None
|
||||
self.data = None
|
||||
self.json = None
|
||||
self.max_retries = max_retries
|
||||
|
||||
# init template
|
||||
self.variable_pool = variable_pool
|
||||
@ -103,9 +107,9 @@ class Executor:
|
||||
if not (key := key.strip()):
|
||||
continue
|
||||
|
||||
value = value[0].strip() if value else ""
|
||||
value_str = value[0].strip() if value else ""
|
||||
result.append(
|
||||
(self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value).text)
|
||||
(self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text)
|
||||
)
|
||||
|
||||
self.params = result
|
||||
@ -140,13 +144,19 @@ class Executor:
|
||||
case "none":
|
||||
self.content = ""
|
||||
case "raw-text":
|
||||
if len(data) != 1:
|
||||
raise RequestBodyError("raw-text body type should have exactly one item")
|
||||
self.content = self.variable_pool.convert_template(data[0].value).text
|
||||
case "json":
|
||||
if len(data) != 1:
|
||||
raise RequestBodyError("json body type should have exactly one item")
|
||||
json_string = self.variable_pool.convert_template(data[0].value).text
|
||||
json_object = json.loads(json_string, strict=False)
|
||||
self.json = json_object
|
||||
# self.json = self._parse_object_contains_variables(json_object)
|
||||
case "binary":
|
||||
if len(data) != 1:
|
||||
raise RequestBodyError("binary body type should have exactly one item")
|
||||
file_selector = data[0].file
|
||||
file_variable = self.variable_pool.get_file(file_selector)
|
||||
if file_variable is None:
|
||||
@ -172,9 +182,10 @@ class Executor:
|
||||
self.variable_pool.convert_template(item.key).text: item.file
|
||||
for item in filter(lambda item: item.type == "file", data)
|
||||
}
|
||||
files: dict[str, Any] = {}
|
||||
files = {k: self.variable_pool.get_file(selector) for k, selector in file_selectors.items()}
|
||||
files = {k: v for k, v in files.items() if v is not None}
|
||||
files = {k: variable.value for k, variable in files.items()}
|
||||
files = {k: variable.value for k, variable in files.items() if variable is not None}
|
||||
files = {
|
||||
k: (v.filename, file_manager.download(v), v.mime_type or "application/octet-stream")
|
||||
for k, v in files.items()
|
||||
@ -241,13 +252,15 @@ class Executor:
|
||||
"params": self.params,
|
||||
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
|
||||
"follow_redirects": True,
|
||||
"max_retries": self.max_retries,
|
||||
}
|
||||
# request_args = {k: v for k, v in request_args.items() if v is not None}
|
||||
try:
|
||||
response = getattr(ssrf_proxy, self.method)(**request_args)
|
||||
except ssrf_proxy.MaxRetriesExceededError as e:
|
||||
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
|
||||
raise HttpRequestNodeError(str(e))
|
||||
return response
|
||||
# FIXME: fix type ignore, this maybe httpx type issue
|
||||
return response # type: ignore
|
||||
|
||||
def invoke(self) -> Response:
|
||||
# assemble headers
|
||||
@ -289,35 +302,37 @@ class Executor:
|
||||
continue
|
||||
raw += f"{k}: {v}\r\n"
|
||||
|
||||
body = ""
|
||||
body_string = ""
|
||||
if self.files:
|
||||
for k, v in self.files.items():
|
||||
body += f"--{boundary}\r\n"
|
||||
body += f'Content-Disposition: form-data; name="{k}"\r\n\r\n'
|
||||
body += f"{v[1]}\r\n"
|
||||
body += f"--{boundary}--\r\n"
|
||||
body_string += f"--{boundary}\r\n"
|
||||
body_string += f'Content-Disposition: form-data; name="{k}"\r\n\r\n'
|
||||
body_string += f"{v[1]}\r\n"
|
||||
body_string += f"--{boundary}--\r\n"
|
||||
elif self.node_data.body:
|
||||
if self.content:
|
||||
if isinstance(self.content, str):
|
||||
body = self.content
|
||||
body_string = self.content
|
||||
elif isinstance(self.content, bytes):
|
||||
body = self.content.decode("utf-8", errors="replace")
|
||||
body_string = self.content.decode("utf-8", errors="replace")
|
||||
elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
|
||||
body = urlencode(self.data)
|
||||
body_string = urlencode(self.data)
|
||||
elif self.data and self.node_data.body.type == "form-data":
|
||||
for key, value in self.data.items():
|
||||
body += f"--{boundary}\r\n"
|
||||
body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
|
||||
body += f"{value}\r\n"
|
||||
body += f"--{boundary}--\r\n"
|
||||
body_string += f"--{boundary}\r\n"
|
||||
body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
|
||||
body_string += f"{value}\r\n"
|
||||
body_string += f"--{boundary}--\r\n"
|
||||
elif self.json:
|
||||
body = json.dumps(self.json)
|
||||
body_string = json.dumps(self.json)
|
||||
elif self.node_data.body.type == "raw-text":
|
||||
body = self.node_data.body.data[0].value
|
||||
if body:
|
||||
raw += f"Content-Length: {len(body)}\r\n"
|
||||
if len(self.node_data.body.data) != 1:
|
||||
raise RequestBodyError("raw-text body type should have exactly one item")
|
||||
body_string = self.node_data.body.data[0].value
|
||||
if body_string:
|
||||
raw += f"Content-Length: {len(body_string)}\r\n"
|
||||
raw += "\r\n" # Empty line between headers and body
|
||||
raw += body
|
||||
raw += body_string
|
||||
|
||||
return raw
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import mimetypes
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod
|
||||
@ -19,7 +20,7 @@ from .entities import (
|
||||
HttpRequestNodeTimeout,
|
||||
Response,
|
||||
)
|
||||
from .exc import HttpRequestNodeError
|
||||
from .exc import HttpRequestNodeError, RequestBodyError
|
||||
|
||||
HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
|
||||
connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
|
||||
@ -35,7 +36,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
_node_type = NodeType.HTTP_REQUEST
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: dict | None = None) -> dict:
|
||||
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
|
||||
return {
|
||||
"type": "http-request",
|
||||
"config": {
|
||||
@ -51,6 +52,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
"max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
|
||||
},
|
||||
},
|
||||
"retry_config": {
|
||||
"max_retries": dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||
"retry_interval": 0.5 * (2**2),
|
||||
"retry_enabled": True,
|
||||
},
|
||||
}
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
@ -60,12 +66,13 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
node_data=self.node_data,
|
||||
timeout=self._get_request_timeout(self.node_data),
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
max_retries=0,
|
||||
)
|
||||
process_data["request"] = http_executor.to_log()
|
||||
|
||||
response = http_executor.invoke()
|
||||
files = self.extract_files(url=http_executor.url, response=response)
|
||||
if not response.response.is_success and self.should_continue_on_error:
|
||||
if not response.response.is_success and (self.should_continue_on_error or self.should_retry):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
outputs={
|
||||
@ -129,9 +136,13 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
data = node_data.body.data
|
||||
match body_type:
|
||||
case "binary":
|
||||
if len(data) != 1:
|
||||
raise RequestBodyError("invalid body data, should have only one item")
|
||||
selector = data[0].file
|
||||
selectors.append(VariableSelector(variable="#" + ".".join(selector) + "#", value_selector=selector))
|
||||
case "json" | "raw-text":
|
||||
if len(data) != 1:
|
||||
raise RequestBodyError("invalid body data, should have only one item")
|
||||
selectors += variable_template_parser.extract_selectors_from_template(data[0].key)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(data[0].value)
|
||||
case "x-www-form-urlencoded":
|
||||
@ -149,27 +160,31 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
)
|
||||
|
||||
mapping = {}
|
||||
for selector in selectors:
|
||||
mapping[node_id + "." + selector.variable] = selector.value_selector
|
||||
for selector_iter in selectors:
|
||||
mapping[node_id + "." + selector_iter.variable] = selector_iter.value_selector
|
||||
|
||||
return mapping
|
||||
|
||||
def extract_files(self, url: str, response: Response) -> list[File]:
|
||||
"""
|
||||
Extract files from response
|
||||
Extract files from response by checking both Content-Type header and URL
|
||||
"""
|
||||
files = []
|
||||
is_file = response.is_file
|
||||
content_type = response.content_type
|
||||
content = response.content
|
||||
|
||||
if is_file and content_type:
|
||||
if is_file:
|
||||
# Guess file extension from URL or Content-Type header
|
||||
filename = url.split("?")[0].split("/")[-1] or ""
|
||||
mime_type = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
|
||||
|
||||
tool_file = ToolFileManager.create_file_by_raw(
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=None,
|
||||
file_binary=content,
|
||||
mimetype=content_type,
|
||||
mimetype=mime_type,
|
||||
)
|
||||
|
||||
mapping = {
|
||||
|
||||
@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from flask import Flask, current_app
|
||||
|
||||
from configs import dify_config
|
||||
from core.variables import IntegerVariable
|
||||
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
|
||||
from core.workflow.entities.node_entities import (
|
||||
NodeRunMetadataKey,
|
||||
NodeRunResult,
|
||||
@ -76,12 +76,15 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
|
||||
variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
|
||||
|
||||
if not iterator_list_segment:
|
||||
raise IteratorVariableNotFoundError(f"Iterator variable {self.node_data.iterator_selector} not found")
|
||||
if not variable:
|
||||
raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")
|
||||
|
||||
if len(iterator_list_segment.value) == 0:
|
||||
if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
|
||||
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
|
||||
|
||||
if isinstance(variable, NoneVariable) or len(variable.value) == 0:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
@ -90,7 +93,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
)
|
||||
return
|
||||
|
||||
iterator_list_value = iterator_list_segment.to_object()
|
||||
iterator_list_value = variable.to_object()
|
||||
|
||||
if not isinstance(iterator_list_value, list):
|
||||
raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
|
||||
@ -360,13 +363,16 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
metadata = event.route_node_state.node_run_result.metadata
|
||||
if not metadata:
|
||||
metadata = {}
|
||||
|
||||
if NodeRunMetadataKey.ITERATION_ID not in metadata:
|
||||
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
|
||||
if self.node_data.is_parallel:
|
||||
metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
|
||||
else:
|
||||
metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index
|
||||
metadata = {
|
||||
**metadata,
|
||||
NodeRunMetadataKey.ITERATION_ID: self.node_id,
|
||||
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID
|
||||
if self.node_data.is_parallel
|
||||
else NodeRunMetadataKey.ITERATION_INDEX: parallel_mode_run_id
|
||||
if self.node_data.is_parallel
|
||||
else iter_run_index,
|
||||
}
|
||||
event.route_node_state.node_run_result.metadata = metadata
|
||||
return event
|
||||
|
||||
|
||||
@ -70,7 +70,20 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
|
||||
except KnowledgeRetrievalNodeError as e:
|
||||
logger.warning("Error when running knowledge retrieval node")
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
# Temporary handle all exceptions from DatasetRetrieval class here.
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
|
||||
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
|
||||
available_datasets = []
|
||||
@ -134,6 +147,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
planning_strategy=planning_strategy,
|
||||
)
|
||||
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
|
||||
if node_data.multiple_retrieval_config is None:
|
||||
raise ValueError("multiple_retrieval_config is required")
|
||||
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
|
||||
if node_data.multiple_retrieval_config.reranking_model:
|
||||
reranking_model = {
|
||||
@ -144,6 +159,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
reranking_model = None
|
||||
weights = None
|
||||
elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":
|
||||
if node_data.multiple_retrieval_config.weights is None:
|
||||
raise ValueError("weights is required")
|
||||
reranking_model = None
|
||||
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
|
||||
weights = {
|
||||
@ -160,18 +177,20 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
reranking_model = None
|
||||
weights = None
|
||||
all_documents = dataset_retrieval.multiple_retrieve(
|
||||
self.app_id,
|
||||
self.tenant_id,
|
||||
self.user_id,
|
||||
self.user_from.value,
|
||||
available_datasets,
|
||||
query,
|
||||
node_data.multiple_retrieval_config.top_k,
|
||||
node_data.multiple_retrieval_config.score_threshold,
|
||||
node_data.multiple_retrieval_config.reranking_mode,
|
||||
reranking_model,
|
||||
weights,
|
||||
node_data.multiple_retrieval_config.reranking_enable,
|
||||
app_id=self.app_id,
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from.value,
|
||||
available_datasets=available_datasets,
|
||||
query=query,
|
||||
top_k=node_data.multiple_retrieval_config.top_k,
|
||||
score_threshold=node_data.multiple_retrieval_config.score_threshold
|
||||
if node_data.multiple_retrieval_config.score_threshold is not None
|
||||
else 0.0,
|
||||
reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
|
||||
reranking_model=reranking_model,
|
||||
weights=weights,
|
||||
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
|
||||
)
|
||||
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
||||
external_documents = [item for item in all_documents if item.provider == "external"]
|
||||
@ -192,7 +211,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
"content": item.page_content,
|
||||
}
|
||||
retrieval_resource_list.append(source)
|
||||
document_score_list = {}
|
||||
document_score_list: dict[str, float] = {}
|
||||
# deal with dify documents
|
||||
if dify_documents:
|
||||
document_score_list = {}
|
||||
@ -247,7 +266,9 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
retrieval_resource_list.append(source)
|
||||
if retrieval_resource_list:
|
||||
retrieval_resource_list = sorted(
|
||||
retrieval_resource_list, key=lambda x: x.get("metadata").get("score") or 0.0, reverse=True
|
||||
retrieval_resource_list,
|
||||
key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0,
|
||||
reverse=True,
|
||||
)
|
||||
position = 1
|
||||
for item in retrieval_resource_list:
|
||||
@ -282,6 +303,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
if node_data.single_retrieval_config is None:
|
||||
raise ValueError("single_retrieval_config is required")
|
||||
model_name = node_data.single_retrieval_config.model.name
|
||||
provider_name = node_data.single_retrieval_config.model.provider
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Literal, Union
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from core.file import File
|
||||
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
|
||||
@ -17,9 +17,9 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
|
||||
_node_type = NodeType.LIST_OPERATOR
|
||||
|
||||
def _run(self):
|
||||
inputs = {}
|
||||
process_data = {}
|
||||
outputs = {}
|
||||
inputs: dict[str, list] = {}
|
||||
process_data: dict[str, list] = {}
|
||||
outputs: dict[str, Any] = {}
|
||||
|
||||
variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable)
|
||||
if variable is None:
|
||||
@ -93,6 +93,8 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
|
||||
def _apply_filter(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
filter_func: Callable[[Any], bool]
|
||||
result: list[Any] = []
|
||||
for condition in self.node_data.filter_by.conditions:
|
||||
if isinstance(variable, ArrayStringSegment):
|
||||
if not isinstance(condition.value, str):
|
||||
@ -236,6 +238,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
|
||||
|
||||
|
||||
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
|
||||
extract_func: Callable[[File], Any]
|
||||
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
|
||||
extract_func = _get_file_extract_string_func(key=key)
|
||||
return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
|
||||
@ -249,47 +252,47 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
|
||||
raise InvalidKeyError(f"Invalid key: {key}")
|
||||
|
||||
|
||||
def _contains(value: str):
|
||||
def _contains(value: str) -> Callable[[str], bool]:
|
||||
return lambda x: value in x
|
||||
|
||||
|
||||
def _startswith(value: str):
|
||||
def _startswith(value: str) -> Callable[[str], bool]:
|
||||
return lambda x: x.startswith(value)
|
||||
|
||||
|
||||
def _endswith(value: str):
|
||||
def _endswith(value: str) -> Callable[[str], bool]:
|
||||
return lambda x: x.endswith(value)
|
||||
|
||||
|
||||
def _is(value: str):
|
||||
def _is(value: str) -> Callable[[str], bool]:
|
||||
return lambda x: x is value
|
||||
|
||||
|
||||
def _in(value: str | Sequence[str]):
|
||||
def _in(value: str | Sequence[str]) -> Callable[[str], bool]:
|
||||
return lambda x: x in value
|
||||
|
||||
|
||||
def _eq(value: int | float):
|
||||
def _eq(value: int | float) -> Callable[[int | float], bool]:
|
||||
return lambda x: x == value
|
||||
|
||||
|
||||
def _ne(value: int | float):
|
||||
def _ne(value: int | float) -> Callable[[int | float], bool]:
|
||||
return lambda x: x != value
|
||||
|
||||
|
||||
def _lt(value: int | float):
|
||||
def _lt(value: int | float) -> Callable[[int | float], bool]:
|
||||
return lambda x: x < value
|
||||
|
||||
|
||||
def _le(value: int | float):
|
||||
def _le(value: int | float) -> Callable[[int | float], bool]:
|
||||
return lambda x: x <= value
|
||||
|
||||
|
||||
def _gt(value: int | float):
|
||||
def _gt(value: int | float) -> Callable[[int | float], bool]:
|
||||
return lambda x: x > value
|
||||
|
||||
|
||||
def _ge(value: int | float):
|
||||
def _ge(value: int | float) -> Callable[[int | float], bool]:
|
||||
return lambda x: x >= value
|
||||
|
||||
|
||||
@ -302,6 +305,7 @@ def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]):
|
||||
|
||||
|
||||
def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]):
|
||||
extract_func: Callable[[File], Any]
|
||||
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}:
|
||||
extract_func = _get_file_extract_string_func(key=order_by)
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
|
||||
|
||||
@ -50,6 +50,7 @@ class PromptConfig(BaseModel):
|
||||
|
||||
|
||||
class LLMNodeChatModelMessage(ChatModelMessage):
|
||||
text: str = ""
|
||||
jinja2_text: Optional[str] = None
|
||||
|
||||
|
||||
|
||||
@ -88,8 +88,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
_node_data_cls = LLMNodeData
|
||||
_node_type = NodeType.LLM
|
||||
|
||||
def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
node_inputs = None
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
node_inputs: Optional[dict[str, Any]] = None
|
||||
process_data = None
|
||||
|
||||
try:
|
||||
@ -145,8 +145,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
query = query_variable.text
|
||||
|
||||
prompt_messages, stop = self._fetch_prompt_messages(
|
||||
user_query=query,
|
||||
user_files=files,
|
||||
sys_query=query,
|
||||
sys_files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
@ -196,7 +196,6 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
@ -206,7 +205,6 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
process_data=process_data,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
|
||||
|
||||
@ -302,7 +300,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
return messages
|
||||
|
||||
def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
|
||||
variables = {}
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
if not node_data.prompt_config:
|
||||
return variables
|
||||
@ -319,7 +317,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
"""
|
||||
# check if it's a context structure
|
||||
if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict:
|
||||
return input_dict["content"]
|
||||
return str(input_dict["content"])
|
||||
|
||||
# else, parse the dict
|
||||
try:
|
||||
@ -545,8 +543,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
def _fetch_prompt_messages(
|
||||
self,
|
||||
*,
|
||||
user_query: str | None = None,
|
||||
user_files: Sequence["File"],
|
||||
sys_query: str | None = None,
|
||||
sys_files: Sequence["File"],
|
||||
context: str | None = None,
|
||||
memory: TokenBufferMemory | None = None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
@ -557,12 +555,13 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
variable_pool: VariablePool,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
|
||||
prompt_messages = []
|
||||
# FIXME: fix the type error cause prompt_messages is type quick a few times
|
||||
prompt_messages: list[Any] = []
|
||||
|
||||
if isinstance(prompt_template, list):
|
||||
# For chat model
|
||||
prompt_messages.extend(
|
||||
_handle_list_messages(
|
||||
self._handle_list_messages(
|
||||
messages=prompt_template,
|
||||
context=context,
|
||||
jinja2_variables=jinja2_variables,
|
||||
@ -581,14 +580,14 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
prompt_messages.extend(memory_messages)
|
||||
|
||||
# Add current query to the prompt messages
|
||||
if user_query:
|
||||
if sys_query:
|
||||
message = LLMNodeChatModelMessage(
|
||||
text=user_query,
|
||||
text=sys_query,
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
)
|
||||
prompt_messages.extend(
|
||||
_handle_list_messages(
|
||||
self._handle_list_messages(
|
||||
messages=[message],
|
||||
context="",
|
||||
jinja2_variables=[],
|
||||
@ -635,24 +634,27 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
raise ValueError("Invalid prompt content type")
|
||||
|
||||
# Add current query to the prompt message
|
||||
if user_query:
|
||||
if sys_query:
|
||||
if prompt_content_type == str:
|
||||
prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
|
||||
prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query)
|
||||
prompt_messages[0].content = prompt_content
|
||||
elif prompt_content_type == list:
|
||||
for content_item in prompt_content:
|
||||
if content_item.type == PromptMessageContentType.TEXT:
|
||||
content_item.data = user_query + "\n" + content_item.data
|
||||
content_item.data = sys_query + "\n" + content_item.data
|
||||
else:
|
||||
raise ValueError("Invalid prompt content type")
|
||||
else:
|
||||
raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
|
||||
|
||||
if vision_enabled and user_files:
|
||||
# The sys_files will be deprecated later
|
||||
if vision_enabled and sys_files:
|
||||
file_prompts = []
|
||||
for file in user_files:
|
||||
for file in sys_files:
|
||||
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
|
||||
file_prompts.append(file_prompt)
|
||||
# If last prompt is a user prompt, add files into its contents,
|
||||
# otherwise append a new user prompt
|
||||
if (
|
||||
len(prompt_messages) > 0
|
||||
and isinstance(prompt_messages[-1], UserPromptMessage)
|
||||
@ -662,7 +664,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
||||
|
||||
# Filter prompt messages
|
||||
# Remove empty messages and filter unsupported content
|
||||
filtered_prompt_messages = []
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message.content, list):
|
||||
@ -780,7 +782,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
else:
|
||||
raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
|
||||
|
||||
variable_mapping = {}
|
||||
variable_mapping: dict[str, Any] = {}
|
||||
for variable_selector in variable_selectors:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
@ -846,6 +848,68 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
},
|
||||
}
|
||||
|
||||
def _handle_list_messages(
|
||||
self,
|
||||
*,
|
||||
messages: Sequence[LLMNodeChatModelMessage],
|
||||
context: Optional[str],
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
variable_pool: VariablePool,
|
||||
vision_detail_config: ImagePromptMessageContent.DETAIL,
|
||||
) -> Sequence[PromptMessage]:
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
for message in messages:
|
||||
if message.edition_type == "jinja2":
|
||||
result_text = _render_jinja2_message(
|
||||
template=message.jinja2_text or "",
|
||||
jinjia2_variables=jinja2_variables,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
prompt_message = _combine_message_content_with_role(
|
||||
contents=[TextPromptMessageContent(data=result_text)], role=message.role
|
||||
)
|
||||
prompt_messages.append(prompt_message)
|
||||
else:
|
||||
# Get segment group from basic message
|
||||
if context:
|
||||
template = message.text.replace("{#context#}", context)
|
||||
else:
|
||||
template = message.text
|
||||
segment_group = variable_pool.convert_template(template)
|
||||
|
||||
# Process segments for images
|
||||
file_contents = []
|
||||
for segment in segment_group.value:
|
||||
if isinstance(segment, ArrayFileSegment):
|
||||
for file in segment.value:
|
||||
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
|
||||
file_content = file_manager.to_prompt_message_content(
|
||||
file, image_detail_config=vision_detail_config
|
||||
)
|
||||
file_contents.append(file_content)
|
||||
elif isinstance(segment, FileSegment):
|
||||
file = segment.value
|
||||
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
|
||||
file_content = file_manager.to_prompt_message_content(
|
||||
file, image_detail_config=vision_detail_config
|
||||
)
|
||||
file_contents.append(file_content)
|
||||
|
||||
# Create message with text from all segments
|
||||
plain_text = segment_group.text
|
||||
if plain_text:
|
||||
prompt_message = _combine_message_content_with_role(
|
||||
contents=[TextPromptMessageContent(data=plain_text)], role=message.role
|
||||
)
|
||||
prompt_messages.append(prompt_message)
|
||||
|
||||
if file_contents:
|
||||
# Create message with image contents
|
||||
prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
|
||||
prompt_messages.append(prompt_message)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
|
||||
def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
|
||||
match role:
|
||||
@ -880,68 +944,6 @@ def _render_jinja2_message(
|
||||
return result_text
|
||||
|
||||
|
||||
def _handle_list_messages(
|
||||
*,
|
||||
messages: Sequence[LLMNodeChatModelMessage],
|
||||
context: Optional[str],
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
variable_pool: VariablePool,
|
||||
vision_detail_config: ImagePromptMessageContent.DETAIL,
|
||||
) -> Sequence[PromptMessage]:
|
||||
prompt_messages = []
|
||||
for message in messages:
|
||||
if message.edition_type == "jinja2":
|
||||
result_text = _render_jinja2_message(
|
||||
template=message.jinja2_text or "",
|
||||
jinjia2_variables=jinja2_variables,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
prompt_message = _combine_message_content_with_role(
|
||||
contents=[TextPromptMessageContent(data=result_text)], role=message.role
|
||||
)
|
||||
prompt_messages.append(prompt_message)
|
||||
else:
|
||||
# Get segment group from basic message
|
||||
if context:
|
||||
template = message.text.replace("{#context#}", context)
|
||||
else:
|
||||
template = message.text
|
||||
segment_group = variable_pool.convert_template(template)
|
||||
|
||||
# Process segments for images
|
||||
file_contents = []
|
||||
for segment in segment_group.value:
|
||||
if isinstance(segment, ArrayFileSegment):
|
||||
for file in segment.value:
|
||||
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
|
||||
file_content = file_manager.to_prompt_message_content(
|
||||
file, image_detail_config=vision_detail_config
|
||||
)
|
||||
file_contents.append(file_content)
|
||||
if isinstance(segment, FileSegment):
|
||||
file = segment.value
|
||||
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
|
||||
file_content = file_manager.to_prompt_message_content(
|
||||
file, image_detail_config=vision_detail_config
|
||||
)
|
||||
file_contents.append(file_content)
|
||||
|
||||
# Create message with text from all segments
|
||||
plain_text = segment_group.text
|
||||
if plain_text:
|
||||
prompt_message = _combine_message_content_with_role(
|
||||
contents=[TextPromptMessageContent(data=plain_text)], role=message.role
|
||||
)
|
||||
prompt_messages.append(prompt_message)
|
||||
|
||||
if file_contents:
|
||||
# Create message with image contents
|
||||
prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
|
||||
prompt_messages.append(prompt_message)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
|
||||
def _calculate_rest_token(
|
||||
*, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
|
||||
) -> int:
|
||||
@ -978,7 +980,7 @@ def _handle_memory_chat_mode(
|
||||
memory_config: MemoryConfig | None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
) -> Sequence[PromptMessage]:
|
||||
memory_messages = []
|
||||
memory_messages: Sequence[PromptMessage] = []
|
||||
# Get messages from memory for chat model
|
||||
if memory and memory_config:
|
||||
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
|
||||
|
||||
@ -14,8 +14,8 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
_node_data_cls = LoopNodeData
|
||||
_node_type = NodeType.LOOP
|
||||
|
||||
def _run(self) -> LoopState:
|
||||
return super()._run()
|
||||
def _run(self) -> LoopState: # type: ignore
|
||||
return super()._run() # type: ignore
|
||||
|
||||
@classmethod
|
||||
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
|
||||
@ -28,7 +28,7 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
|
||||
# TODO waiting for implementation
|
||||
return [
|
||||
Condition(
|
||||
Condition( # type: ignore
|
||||
variable_selector=[node_id, "index"],
|
||||
comparison_operator="≤",
|
||||
value_type="value_selector",
|
||||
|
||||
@ -25,7 +25,7 @@ class ParameterConfig(BaseModel):
|
||||
raise ValueError("Parameter name is required")
|
||||
if value in {"__reason", "__is_success"}:
|
||||
raise ValueError("Invalid parameter name, __reason and __is_success are reserved")
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
|
||||
class ParameterExtractorNodeData(BaseNodeData):
|
||||
@ -52,7 +52,7 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
|
||||
:return: parameter json schema
|
||||
"""
|
||||
parameters = {"type": "object", "properties": {}, "required": []}
|
||||
parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
|
||||
|
||||
for parameter in self.parameters:
|
||||
parameter_schema: dict[str, Any] = {"description": parameter.description}
|
||||
|
||||
@ -63,7 +63,8 @@ class ParameterExtractorNode(LLMNode):
|
||||
Parameter Extractor Node.
|
||||
"""
|
||||
|
||||
_node_data_cls = ParameterExtractorNodeData
|
||||
# FIXME: figure out why here is different from super class
|
||||
_node_data_cls = ParameterExtractorNodeData # type: ignore
|
||||
_node_type = NodeType.PARAMETER_EXTRACTOR
|
||||
|
||||
_model_instance: Optional[ModelInstance] = None
|
||||
@ -179,6 +180,15 @@ class ParameterExtractorNode(LLMNode):
|
||||
error=str(e),
|
||||
metadata={},
|
||||
)
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)},
|
||||
error=str(e),
|
||||
metadata={},
|
||||
)
|
||||
|
||||
error = None
|
||||
|
||||
@ -244,6 +254,9 @@ class ParameterExtractorNode(LLMNode):
|
||||
# deduct quota
|
||||
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
|
||||
if text is None:
|
||||
text = ""
|
||||
|
||||
return text, usage, tool_call
|
||||
|
||||
def _generate_function_call_prompt(
|
||||
@ -596,9 +609,10 @@ class ParameterExtractorNode(LLMNode):
|
||||
json_str = extract_json(result[idx:])
|
||||
if json_str:
|
||||
try:
|
||||
return json.loads(json_str)
|
||||
return cast(dict, json.loads(json_str))
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
|
||||
"""
|
||||
@ -607,13 +621,13 @@ class ParameterExtractorNode(LLMNode):
|
||||
if not tool_call or not tool_call.function.arguments:
|
||||
return None
|
||||
|
||||
return json.loads(tool_call.function.arguments)
|
||||
return cast(dict, json.loads(tool_call.function.arguments))
|
||||
|
||||
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
|
||||
"""
|
||||
Generate default result.
|
||||
"""
|
||||
result = {}
|
||||
result: dict[str, Any] = {}
|
||||
for parameter in data.parameters:
|
||||
if parameter.type == "number":
|
||||
result[parameter.name] = 0
|
||||
@ -763,7 +777,7 @@ class ParameterExtractorNode(LLMNode):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
node_data: ParameterExtractorNodeData, # type: ignore
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@ -772,6 +786,7 @@ class ParameterExtractorNode(LLMNode):
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
# FIXME: fix the type error later
|
||||
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
|
||||
|
||||
if node_data.instruction:
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters"
|
||||
|
||||
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy.
|
||||
@ -35,7 +37,7 @@ FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information fr
|
||||
</structure>
|
||||
""" # noqa: E501
|
||||
|
||||
FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [
|
||||
FUNCTION_CALLING_EXTRACTOR_EXAMPLE: list[dict[str, Any]] = [
|
||||
{
|
||||
"user": {
|
||||
"query": "What is the weather today in SF?",
|
||||
|
||||
@ -1,10 +1,8 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
|
||||
@ -36,12 +34,9 @@ from .template_prompts import (
|
||||
QUESTION_CLASSIFIER_USER_PROMPT_3,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file import File
|
||||
|
||||
|
||||
class QuestionClassifierNode(LLMNode):
|
||||
_node_data_cls = QuestionClassifierNodeData
|
||||
_node_data_cls = QuestionClassifierNodeData # type: ignore
|
||||
_node_type = NodeType.QUESTION_CLASSIFIER
|
||||
|
||||
def _run(self):
|
||||
@ -66,7 +61,7 @@ class QuestionClassifierNode(LLMNode):
|
||||
node_data.instruction = variable_pool.convert_template(
|
||||
node_data.instruction).text
|
||||
|
||||
files: Sequence[File] = (
|
||||
files = (
|
||||
self._fetch_files(
|
||||
selector=node_data.vision.configs.variable_selector,
|
||||
)
|
||||
@ -89,37 +84,38 @@ class QuestionClassifierNode(LLMNode):
|
||||
)
|
||||
prompt_messages, stop = self._fetch_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
user_query=query,
|
||||
sys_query=query,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
user_files=files,
|
||||
sys_files=files,
|
||||
vision_enabled=node_data.vision.enabled,
|
||||
vision_detail=node_data.vision.configs.detail,
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=[],
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
generator = self._invoke_llm(
|
||||
node_data_model=node_data.model,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
result_text = ""
|
||||
usage = LLMUsage.empty_usage()
|
||||
finish_reason = None
|
||||
for event in generator:
|
||||
if isinstance(event, ModelInvokeCompletedEvent):
|
||||
result_text = event.text
|
||||
usage = event.usage
|
||||
finish_reason = event.finish_reason
|
||||
break
|
||||
|
||||
category_name = node_data.classes[0].name
|
||||
category_id = node_data.classes[0].id
|
||||
try:
|
||||
# handle invoke result
|
||||
generator = self._invoke_llm(
|
||||
node_data_model=node_data.model,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
for event in generator:
|
||||
if isinstance(event, ModelInvokeCompletedEvent):
|
||||
result_text = event.text
|
||||
usage = event.usage
|
||||
finish_reason = event.finish_reason
|
||||
break
|
||||
|
||||
category_name = node_data.classes[0].name
|
||||
category_id = node_data.classes[0].id
|
||||
result_text_json = parse_and_check_json_markdown(result_text, [])
|
||||
# result_text_json = json.loads(result_text.strip('```JSON\n'))
|
||||
if "category_name" in result_text_json and "category_id" in result_text_json:
|
||||
@ -130,10 +126,6 @@ class QuestionClassifierNode(LLMNode):
|
||||
if category_id_result in category_ids:
|
||||
category_name = classes_map[category_id_result]
|
||||
category_id = category_id_result
|
||||
|
||||
except OutputParserError:
|
||||
logging.exception(f"Failed to parse result text: {result_text}")
|
||||
try:
|
||||
process_data = {
|
||||
"model_mode": model_config.mode,
|
||||
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
@ -157,7 +149,6 @@ class QuestionClassifierNode(LLMNode):
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
@ -177,7 +168,7 @@ class QuestionClassifierNode(LLMNode):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: QuestionClassifierNodeData,
|
||||
node_data: Any,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@ -186,6 +177,7 @@ class QuestionClassifierNode(LLMNode):
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
node_data = cast(QuestionClassifierNodeData, node_data)
|
||||
variable_mapping = {"query": node_data.query_variable_selector}
|
||||
variable_selectors = []
|
||||
if node_data.instruction:
|
||||
|
||||
@ -9,7 +9,6 @@ from core.file import File, FileTransferMethod, FileType
|
||||
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
@ -58,6 +57,8 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
|
||||
# get tool runtime
|
||||
try:
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(
|
||||
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
|
||||
)
|
||||
@ -145,7 +146,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
"""
|
||||
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
|
||||
|
||||
result = {}
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in node_data.tool_parameters:
|
||||
parameter = tool_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
class VariableOperatorNodeError(Exception):
|
||||
class VariableOperatorNodeError(ValueError):
|
||||
"""Base error type, don't use directly."""
|
||||
|
||||
pass
|
||||
|
||||
@ -36,6 +36,8 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
||||
|
||||
case WriteMode.CLEAR:
|
||||
income_value = get_zero_value(original_variable.value_type)
|
||||
if income_value is None:
|
||||
raise VariableOperatorNodeError("income value not found")
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
|
||||
|
||||
case _:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
@ -29,7 +29,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
inputs = self.node_data.model_dump()
|
||||
process_data = {}
|
||||
process_data: dict[str, Any] = {}
|
||||
# NOTE: This node has no outputs
|
||||
updated_variables: list[Variable] = []
|
||||
|
||||
@ -119,7 +119,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
|
||||
else:
|
||||
conversation_id = conversation_id.value
|
||||
common_helpers.update_conversation_variable(
|
||||
conversation_id=conversation_id,
|
||||
conversation_id=cast(str, conversation_id),
|
||||
variable=variable,
|
||||
)
|
||||
|
||||
|
||||
@ -129,11 +129,11 @@ class WorkflowEntry:
|
||||
:return:
|
||||
"""
|
||||
# fetch node info from workflow graph
|
||||
graph = workflow.graph_dict
|
||||
if not graph:
|
||||
workflow_graph = workflow.graph_dict
|
||||
if not workflow_graph:
|
||||
raise ValueError("workflow graph not found")
|
||||
|
||||
nodes = graph.get("nodes")
|
||||
nodes = workflow_graph.get("nodes")
|
||||
if not nodes:
|
||||
raise ValueError("nodes not found in workflow graph")
|
||||
|
||||
@ -297,7 +297,8 @@ class WorkflowEntry:
|
||||
|
||||
@staticmethod
|
||||
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
|
||||
return WorkflowEntry._handle_special_values(value)
|
||||
result = WorkflowEntry._handle_special_values(value)
|
||||
return result if isinstance(result, Mapping) or result is None else dict(result)
|
||||
|
||||
@staticmethod
|
||||
def _handle_special_values(value: Any) -> Any:
|
||||
@ -309,10 +310,10 @@ class WorkflowEntry:
|
||||
res[k] = WorkflowEntry._handle_special_values(v)
|
||||
return res
|
||||
if isinstance(value, list):
|
||||
res = []
|
||||
res_list = []
|
||||
for item in value:
|
||||
res.append(WorkflowEntry._handle_special_values(item))
|
||||
return res
|
||||
res_list.append(WorkflowEntry._handle_special_values(item))
|
||||
return res_list
|
||||
if isinstance(value, File):
|
||||
return value.to_dict()
|
||||
return value
|
||||
|
||||
Reference in New Issue
Block a user