Merge fix/chore-fix into dev/plugin-deploy

This commit is contained in:
Yeuoly
2024-12-25 15:12:05 +08:00
733 changed files with 7774 additions and 5226 deletions

View File

@ -33,7 +33,7 @@ _TEXT_COLOR_MAPPING = {
class WorkflowLoggingCallback(WorkflowCallback):
def __init__(self) -> None:
self.current_node_id = None
self.current_node_id: Optional[str] = None
def on_event(self, event: GraphEngineEvent) -> None:
if isinstance(event, GraphRunStartedEvent):

View File

@ -37,12 +37,15 @@ class NodeRunResult(BaseModel):
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Optional[Mapping[str, Any]] = None # node inputs
process_data: Optional[dict[str, Any]] = None # process data
process_data: Optional[Mapping[str, Any]] = None # process data
outputs: Optional[Mapping[str, Any]] = None # node outputs
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # node metadata
llm_usage: Optional[LLMUsage] = None # llm usage
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
error: Optional[str] = None # error message if status is failed
error_type: Optional[str] = None # error type if status is failed
# single step node run retry
retry_index: int = 0
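The dict-to-Mapping switch above signals read-only intent: a Mapping parameter accepts any dict-like value while letting a type checker flag mutation. A minimal, runnable sketch with standalone types (not the Dify models):

from collections.abc import Mapping
from types import MappingProxyType
from typing import Any, Optional

def summarize_metadata(metadata: Optional[Mapping[str, Any]]) -> str:
    # Mapping promises read access only: assigning metadata["k"] = v here
    # would be flagged by a type checker, while callers may pass a plain
    # dict or a read-only MappingProxyType alike.
    if not metadata:
        return "no metadata"
    return ", ".join(f"{key}={value}" for key, value in metadata.items())

print(summarize_metadata({"total_tokens": 42}))
print(summarize_metadata(MappingProxyType({"total_tokens": 42})))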

View File

@ -5,7 +5,7 @@ from core.workflow.utils.condition.processor import ConditionProcessor
class ConditionRunConditionHandlerHandler(RunConditionHandler):
def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState):
"""
Check if the condition can be executed

View File

@ -33,7 +33,7 @@ class GraphRunSucceededEvent(BaseGraphEvent):
class GraphRunFailedEvent(BaseGraphEvent):
error: str = Field(..., description="failed reason")
exceptions_count: Optional[int] = Field(description="exception count", default=0)
exceptions_count: int = Field(description="exception count", default=0)
class GraphRunPartialSucceededEvent(BaseGraphEvent):
@ -97,6 +97,12 @@ class NodeInIterationFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeRunRetryEvent(NodeRunStartedEvent):
error: str = Field(..., description="error")
retry_index: int = Field(..., description="which retry attempt is about to be performed")
start_at: datetime = Field(..., description="retry start time")
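NodeRunRetryEvent extends NodeRunStartedEvent with the error, the attempt number, and the retry start time. A minimal sketch of consuming such an event, using a plain dataclass as a stand-in for the Pydantic model:

from dataclasses import dataclass
from datetime import datetime, timezone

@dataclass
class RetryEvent:
    # stand-in for NodeRunRetryEvent's retry-specific fields
    error: str
    retry_index: int
    start_at: datetime

def format_retry(event: RetryEvent) -> str:
    return f"retry #{event.retry_index} at {event.start_at.isoformat()}: {event.error}"

print(format_retry(RetryEvent("timeout", 1, datetime.now(timezone.utc))))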
###########################################
# Parallel Branch Events
###########################################

View File

@ -1,9 +1,11 @@
import uuid
from collections import defaultdict
from collections.abc import Mapping
from typing import Any, Optional, cast
from pydantic import BaseModel, Field
from configs import dify_config
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes import NodeType
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
@ -170,7 +172,9 @@ class Graph(BaseModel):
for parallel in parallel_mapping.values():
if parallel.parent_parallel_id:
cls._check_exceed_parallel_limit(
parallel_mapping=parallel_mapping, level_limit=3, parent_parallel_id=parallel.parent_parallel_id
parallel_mapping=parallel_mapping,
level_limit=dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
parent_parallel_id=parallel.parent_parallel_id,
)
# init answer stream generate routes
@ -307,26 +311,17 @@ class Graph(BaseModel):
parallel = None
if len(target_node_edges) > 1:
# fetch all node ids in current parallels
parallel_branch_node_ids = {}
condition_edge_mappings = {}
parallel_branch_node_ids = defaultdict(list)
condition_edge_mappings = defaultdict(list)
for graph_edge in target_node_edges:
if graph_edge.run_condition is None:
if "default" not in parallel_branch_node_ids:
parallel_branch_node_ids["default"] = []
parallel_branch_node_ids["default"].append(graph_edge.target_node_id)
else:
condition_hash = graph_edge.run_condition.hash
if condition_hash not in condition_edge_mappings:
condition_edge_mappings[condition_hash] = []
condition_edge_mappings[condition_hash].append(graph_edge)
for condition_hash, graph_edges in condition_edge_mappings.items():
if len(graph_edges) > 1:
if condition_hash not in parallel_branch_node_ids:
parallel_branch_node_ids[condition_hash] = []
for graph_edge in graph_edges:
parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id)
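The refactor above swaps the check-then-initialize idiom for defaultdict(list), which creates the list on first access. A small equivalence sketch over toy edges (tuples of target id and optional condition hash):

from collections import defaultdict

edges = [("a", None), ("b", "h1"), ("c", "h1")]

# before: initialize the bucket by hand
manual: dict[str, list[str]] = {}
for target, cond in edges:
    key = cond or "default"
    if key not in manual:
        manual[key] = []
    manual[key].append(target)

# after: defaultdict creates the bucket on first access
grouped: defaultdict[str, list[str]] = defaultdict(list)
for target, cond in edges:
    grouped[cond or "default"].append(target)

assert manual == dict(grouped)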
@ -415,7 +410,7 @@ class Graph(BaseModel):
if condition_edge_mappings:
for condition_hash, graph_edges in condition_edge_mappings.items():
for graph_edge in graph_edges:
current_parallel: GraphParallel | None = cls._get_current_parallel(
current_parallel = cls._get_current_parallel(
parallel_mapping=parallel_mapping,
graph_edge=graph_edge,
parallel=condition_parallels.get(condition_hash),

View File

@ -6,6 +6,7 @@ import uuid
from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy, deepcopy
from datetime import UTC, datetime
from typing import Any, Optional, cast
from flask import Flask, current_app
@ -26,6 +27,7 @@ from core.workflow.graph_engine.entities.event import (
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
@ -39,6 +41,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntime
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes import NodeType
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
@ -65,7 +68,7 @@ class GraphEngineThreadPool(ThreadPoolExecutor):
self.max_submit_count = max_submit_count
self.submit_count = 0
def submit(self, fn, *args, **kwargs):
def submit(self, fn, /, *args, **kwargs):
self.submit_count += 1
self.check_is_full()
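The added "/" makes fn positional-only, matching the signature of ThreadPoolExecutor.submit in the standard library, so a submitted task may itself take a keyword argument named fn without a collision. A runnable sketch with a hypothetical CountingPool:

from concurrent.futures import ThreadPoolExecutor

class CountingPool(ThreadPoolExecutor):
    def __init__(self, max_workers: int = 2) -> None:
        super().__init__(max_workers=max_workers)
        self.submit_count = 0

    def submit(self, fn, /, *args, **kwargs):
        # "/" keeps fn positional-only, so fn="payload" below lands in
        # **kwargs instead of clashing with the first parameter
        self.submit_count += 1
        return super().submit(fn, *args, **kwargs)

def task(*, fn: str) -> str:
    return f"ran {fn}"

with CountingPool() as pool:
    future = pool.submit(task, fn="payload")
    print(future.result(), pool.submit_count)  # ran payload 1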
@ -139,7 +142,8 @@ class GraphEngine:
def run(self) -> Generator[GraphEngineEvent, None, None]:
# trigger graph run start event
yield GraphRunStartedEvent()
handle_exceptions = []
handle_exceptions: list[str] = []
stream_processor: StreamProcessor
try:
if self.init_params.workflow_type == WorkflowType.CHAT:
@ -349,7 +353,7 @@ class GraphEngine:
if any(edge.run_condition for edge in edge_mappings):
# if nodes has run conditions, get node id which branch to take based on the run condition results
condition_edge_mappings = {}
condition_edge_mappings: dict[str, list[GraphEdge]] = {}
for edge in edge_mappings:
if edge.run_condition:
run_condition_hash = edge.run_condition.hash
@ -363,6 +367,9 @@ class GraphEngine:
continue
edge = cast(GraphEdge, sub_edge_mappings[0])
if edge.run_condition is None:
logger.warning(f"Edge {edge.target_node_id} run condition is None")
continue
result = ConditionManager.get_condition_handler(
init_params=self.init_params,
@ -386,11 +393,11 @@ class GraphEngine:
handle_exceptions=handle_exceptions,
)
for item in parallel_generator:
if isinstance(item, str):
final_node_id = item
for parallel_result in parallel_generator:
if isinstance(parallel_result, str):
final_node_id = parallel_result
else:
yield item
yield parallel_result
break
@ -412,11 +419,11 @@ class GraphEngine:
handle_exceptions=handle_exceptions,
)
for item in parallel_generator:
if isinstance(item, str):
final_node_id = item
for generated_item in parallel_generator:
if isinstance(generated_item, str):
final_node_id = generated_item
else:
yield item
yield generated_item
if not final_node_id:
break
@ -587,7 +594,7 @@ class GraphEngine:
def _run_node(
self,
node_instance: BaseNode,
node_instance: BaseNode[BaseNodeData],
route_node_state: RouteNodeState,
parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
@ -613,36 +620,120 @@ class GraphEngine:
)
db.session.close()
max_retries = node_instance.node_data.retry_config.max_retries
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
retries = 0
should_continue_retry = True
while should_continue_retry and retries <= max_retries:
try:
# run node
retry_start_at = datetime.now(UTC).replace(tzinfo=None)
generator = node_instance.run()
for item in generator:
if isinstance(item, GraphEngineEvent):
if isinstance(item, BaseIterationEvent):
# add parallel info to iteration event
item.parallel_id = parallel_id
item.parallel_start_node_id = parallel_start_node_id
item.parent_parallel_id = parent_parallel_id
item.parent_parallel_start_node_id = parent_parallel_start_node_id
try:
# run node
generator = node_instance.run()
for item in generator:
if isinstance(item, GraphEngineEvent):
if isinstance(item, BaseIterationEvent):
# add parallel info to iteration event
item.parallel_id = parallel_id
item.parallel_start_node_id = parallel_start_node_id
item.parent_parallel_id = parent_parallel_id
item.parent_parallel_start_node_id = parent_parallel_start_node_id
yield item
else:
if isinstance(item, RunCompletedEvent):
run_result = item.run_result
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if (
retries == max_retries
and node_instance.node_type == NodeType.HTTP_REQUEST
and run_result.outputs
and not node_instance.should_continue_on_error
):
run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
if node_instance.should_retry and retries < max_retries:
retries += 1
route_node_state.node_run_result = run_result
yield NodeRunRetryEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
predecessor_node_id=node_instance.previous_node_id,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
error=run_result.error or "Unknown error",
retry_index=retries,
start_at=retry_start_at,
)
time.sleep(retry_interval)
continue
route_node_state.set_finished(run_result=run_result)
yield item
else:
if isinstance(item, RunCompletedEvent):
run_result = item.run_result
route_node_state.set_finished(run_result=run_result)
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if node_instance.should_continue_on_error:
# if run failed, handle error
run_result = self._handle_continue_on_error(
node_instance,
item.run_result,
self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions,
)
route_node_state.node_run_result = run_result
route_node_state.status = RouteNodeState.Status.EXCEPTION
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
yield NodeRunExceptionEvent(
error=run_result.error or "System Error",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
should_continue_retry = False
else:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
should_continue_retry = False
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
node_instance.node_id
):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
)
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if node_instance.should_continue_on_error:
# if run failed, handle error
run_result = self._handle_continue_on_error(
node_instance,
item.run_result,
self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions,
)
route_node_state.node_run_result = run_result
route_node_state.status = RouteNodeState.Status.EXCEPTION
if run_result.llm_usage:
# use the latest usage
self.graph_runtime_state.llm_usage += run_result.llm_usage
# append node output variables to variable pool
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
@ -651,133 +742,86 @@ class GraphEngine:
variable_key_list=[variable_key],
variable_value=variable_value,
)
yield NodeRunExceptionEvent(
error=run_result.error or "System Error",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
else:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
node_instance.node_id
):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
)
if run_result.llm_usage:
# use the latest usage
self.graph_runtime_state.llm_usage += run_result.llm_usage
# append node output variables to variable pool
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
# add parallel info to run result metadata
if parallel_id and parallel_start_node_id:
# When setting metadata, convert to dict first
if not run_result.metadata:
run_result.metadata = {}
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
if parent_parallel_id and parent_parallel_start_node_id:
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
parent_parallel_start_node_id
)
if parallel_id and parallel_start_node_id:
metadata_dict = dict(run_result.metadata)
metadata_dict[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
metadata_dict[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
if parent_parallel_id and parent_parallel_start_node_id:
metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
parent_parallel_start_node_id
)
run_result.metadata = metadata_dict
yield NodeRunSucceededEvent(
yield NodeRunSucceededEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
should_continue_retry = False
break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
chunk_content=item.chunk_content,
from_variable_selector=item.from_variable_selector,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
chunk_content=item.chunk_content,
from_variable_selector=item.from_variable_selector,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
retriever_resources=item.retriever_resources,
context=item.context,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
except GenerateTaskStoppedError:
# trigger node run failed event
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
error="Workflow stopped.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
return
except Exception as e:
logger.exception(f"Node {node_instance.node_data.title} run failed")
raise e
finally:
db.session.close()
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
retriever_resources=item.retriever_resources,
context=item.context,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
except GenerateTaskStoppedError:
# trigger node run failed event
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
error="Workflow stopped.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
return
except Exception as e:
logger.exception(f"Node {node_instance.node_data.title} run failed")
raise e
finally:
db.session.close()
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
"""
@ -836,8 +880,8 @@ class GraphEngine:
variable_pool.add([node_instance.node_id, "error_message"], error_result.error)
variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type)
# add error message to handle_exceptions
handle_exceptions.append(error_result.error)
node_error_args = {
handle_exceptions.append(error_result.error or "")
node_error_args: dict[str, Any] = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": error_result.error,
"inputs": error_result.inputs,

View File

@ -147,6 +147,8 @@ class AnswerStreamGeneratorRouter:
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
for edge in reverse_edges:
source_node_id = edge.source_node_id
if source_node_id not in node_id_config_mapping:
continue
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
source_node_data = node_id_config_mapping[source_node_id].get("data", {})
if (

View File

@ -60,11 +60,10 @@ class AnswerStreamProcessor(StreamProcessor):
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
# remove unreachable nodes
self._remove_unreachable_nodes(event)
# generate stream outputs
yield from self._generate_stream_outputs_when_node_finished(event)
yield from self._generate_stream_outputs_when_node_finished(cast(NodeRunSucceededEvent, event))
else:
yield event
@ -131,7 +130,7 @@ class AnswerStreamProcessor(StreamProcessor):
node_type=event.node_type,
node_data=event.node_data,
chunk_content=text,
from_variable_selector=value_selector,
from_variable_selector=list(value_selector),
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,

View File

@ -1,10 +1,13 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph
logger = logging.getLogger(__name__)
class StreamProcessor(ABC):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
@ -16,7 +19,7 @@ class StreamProcessor(ABC):
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
raise NotImplementedError
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None:
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return
@ -29,15 +32,24 @@ class StreamProcessor(ABC):
return
if run_result.edge_source_handle:
reachable_node_ids = []
unreachable_first_node_ids = []
reachable_node_ids: list[str] = []
unreachable_first_node_ids: list[str] = []
if finished_node_id not in self.graph.edge_mapping:
logger.warning(f"node {finished_node_id} has no edge mapping")
return
for edge in self.graph.edge_mapping[finished_node_id]:
if (
edge.run_condition
and edge.run_condition.branch_identify
and run_result.edge_source_handle == edge.run_condition.branch_identify
):
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
# remove unreachable nodes
# FIXME: branches can converge directly, so extending reachable nodes here
# may short-circuit the answer node. The line below is commented out for now;
# this has no effect on the answer node or the workflow. Re-enable it once
# a better solution exists. Issues: #11542 #9560 #10638 #10564
# reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
continue
else:
unreachable_first_node_ids.append(edge.target_node_id)

View File

@ -38,7 +38,8 @@ class DefaultValue(BaseModel):
@staticmethod
def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
"""Unified array type validation"""
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
# FIXME: type ignore added because the reason mypy complains here is unknown; remove it once the root cause is found
return isinstance(value, list) and all(isinstance(x, element_type) for x in value) # type: ignore
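Restated on its own, the helper above accepts a value only if it is a list whose elements are all instances of element_type. A standalone sketch without Pydantic:

def validate_array(value: object, element_type: type) -> bool:
    # True only for a list whose elements all match element_type
    return isinstance(value, list) and all(isinstance(x, element_type) for x in value)

print(validate_array([1, 2, 3], int))     # True
print(validate_array([1, "2"], int))      # False
print(validate_array("not a list", str))  # False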
@staticmethod
def _convert_number(value: str) -> float:
@ -84,7 +85,7 @@ class DefaultValue(BaseModel):
},
}
validator = type_validators.get(self.type)
validator: dict[str, Any] = type_validators.get(self.type, {})
if not validator:
if self.type == DefaultValueType.ARRAY_FILES:
# Handle files type
@ -106,12 +107,25 @@ class DefaultValue(BaseModel):
return self
class RetryConfig(BaseModel):
"""node retry config"""
max_retries: int = 0 # max retry times
retry_interval: int = 0 # retry interval in milliseconds
retry_enabled: bool = False # whether retry is enabled
@property
def retry_interval_seconds(self) -> float:
return self.retry_interval / 1000
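retry_interval is stored in milliseconds, so the property divides by 1000 before the engine sleeps between attempts. A simplified, exception-based sketch of the fixed-interval retry loop (the engine itself branches on run_result.status rather than exceptions; flaky is a hypothetical failing function):

import time

class RetryConfig:
    def __init__(self, max_retries: int = 0, retry_interval: int = 0) -> None:
        self.max_retries = max_retries
        self.retry_interval = retry_interval  # milliseconds

    @property
    def retry_interval_seconds(self) -> float:
        return self.retry_interval / 1000

def run_with_retries(fn, config: RetryConfig):
    retries = 0
    while True:
        try:
            return fn()
        except Exception:
            if retries >= config.max_retries:
                raise
            retries += 1
            time.sleep(config.retry_interval_seconds)

attempts = {"n": 0}

def flaky() -> str:
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise RuntimeError("transient failure")
    return "ok"

print(run_with_retries(flaky, RetryConfig(max_retries=3, retry_interval=100)))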
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
error_strategy: Optional[ErrorStrategy] = None
default_value: Optional[list[DefaultValue]] = None
version: str = "1"
retry_config: RetryConfig = RetryConfig()
@property
def default_value_dict(self):

View File

@ -1,4 +1,4 @@
class BaseNodeError(Exception):
class BaseNodeError(ValueError):
"""Base class for node errors."""
pass

View File

@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from models.workflow import WorkflowNodeExecutionStatus
@ -72,7 +72,11 @@ class BaseNode(Generic[GenericNodeData]):
result = self._run()
except Exception as e:
logger.exception(f"Node {self.node_id} failed to run")
result = NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=str(e), error_type="SystemError")
result = NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
error_type="WorkflowNodeError",
)
if isinstance(result, NodeRunResult):
yield RunCompletedEvent(run_result=result)
@ -143,3 +147,12 @@ class BaseNode(Generic[GenericNodeData]):
bool: if should continue on error
"""
return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
@property
def should_retry(self) -> bool:
"""judge if should retry
Returns:
bool: if should retry
"""
return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union
from typing import Any, Optional
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
@ -59,7 +59,7 @@ class CodeNode(BaseNode[CodeNodeData]):
)
# Transform result
result = self._transform_result(result, self.node_data.outputs)
result = self._transform_result(result=result, output_schema=self.node_data.outputs)
except (CodeExecutionError, CodeNodeError) as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
@ -67,18 +67,17 @@ class CodeNode(BaseNode[CodeNodeData]):
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
def _check_string(self, value: str, variable: str) -> str:
def _check_string(self, value: str | None, variable: str) -> str | None:
"""
Check string
:param value: value
:param variable: variable
:return:
"""
if value is None:
return None
if not isinstance(value, str):
if value is None:
return None
else:
raise OutputValidationError(f"Output variable `{variable}` must be a string")
raise OutputValidationError(f"Output variable `{variable}` must be a string")
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
raise OutputValidationError(
@ -88,18 +87,17 @@ class CodeNode(BaseNode[CodeNodeData]):
return value.replace("\x00", "")
def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]:
def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
"""
Check number
:param value: value
:param variable: variable
:return:
"""
if value is None:
return None
if not isinstance(value, int | float):
if value is None:
return None
else:
raise OutputValidationError(f"Output variable `{variable}` must be a number")
raise OutputValidationError(f"Output variable `{variable}` must be a number")
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
raise OutputValidationError(
@ -118,18 +116,16 @@ class CodeNode(BaseNode[CodeNodeData]):
return value
def _transform_result(
self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1
) -> dict:
"""
Transform result
:param result: result
:param output_schema: output schema
:return:
"""
self,
result: Mapping[str, Any],
output_schema: Optional[dict[str, CodeNodeData.Output]],
prefix: str = "",
depth: int = 1,
):
if depth > dify_config.CODE_MAX_DEPTH:
raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
transformed_result = {}
transformed_result: dict[str, Any] = {}
if output_schema is None:
# validate output through instance-type checks
for output_name, output_value in result.items():

View File

@ -14,7 +14,7 @@ class CodeNodeData(BaseNodeData):
class Output(BaseModel):
type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
children: Optional[dict[str, "Output"]] = None
children: Optional[dict[str, "CodeNodeData.Output"]] = None
class Dependency(BaseModel):
name: str

View File

@ -1,19 +1,15 @@
import csv
import io
import json
import logging
import os
import tempfile
from typing import cast
import docx
import pandas as pd
import pypdfium2 # type: ignore
import yaml # type: ignore
from unstructured.partition.api import partition_via_api
from unstructured.partition.email import partition_email
from unstructured.partition.epub import partition_epub
from unstructured.partition.msg import partition_msg
from unstructured.partition.ppt import partition_ppt
from unstructured.partition.pptx import partition_pptx
from configs import dify_config
from core.file import File, FileTransferMethod, file_manager
@ -28,6 +24,8 @@ from models.workflow import WorkflowNodeExecutionStatus
from .entities import DocumentExtractorNodeData
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
logger = logging.getLogger(__name__)
class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
"""
@ -162,7 +160,7 @@ def _extract_text_from_yaml(file_content: bytes) -> str:
"""Extract the content from yaml file"""
try:
yaml_data = yaml.safe_load_all(file_content.decode("utf-8", "ignore"))
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
except (UnicodeDecodeError, yaml.YAMLError) as e:
raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e
@ -183,10 +181,43 @@ def _extract_text_from_pdf(file_content: bytes) -> str:
def _extract_text_from_doc(file_content: bytes) -> str:
"""
Extract text from a DOC/DOCX file.
For now, only paragraphs and tables are supported; add more element types if needed.
"""
try:
doc_file = io.BytesIO(file_content)
doc = docx.Document(doc_file)
return "\n".join([paragraph.text for paragraph in doc.paragraphs])
text = []
# Process paragraphs
for paragraph in doc.paragraphs:
if paragraph.text.strip():
text.append(paragraph.text)
# Process tables
for table in doc.tables:
# Table header
try:
# table parsing may raise errors, so ignore failures for this table.
if len(table.rows) > 0 and table.rows[0].cells is not None:
# Check if any cell in the table has text
has_content = False
for row in table.rows:
if any(cell.text.strip() for cell in row.cells):
has_content = True
break
if has_content:
markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n"
markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n"
for row in table.rows[1:]:
markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n"
text.append(markdown_table)
except Exception as e:
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
continue
return "\n".join(text)
except Exception as e:
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
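The table branch above renders each non-empty DOCX table as a Markdown pipe table: a header row, a separator row, then the data rows. A docx-free sketch of that rendering, assuming the rows were already extracted as lists of cell strings:

def table_to_markdown(rows: list[list[str]]) -> str:
    # header row, separator, then data rows, mirroring the handling above
    if not rows or not any(cell.strip() for row in rows for cell in row):
        return ""
    table = "| " + " | ".join(rows[0]) + " |\n"
    table += "| " + " | ".join(["---"] * len(rows[0])) + " |\n"
    for row in rows[1:]:
        table += "| " + " | ".join(row) + " |\n"
    return table

print(table_to_markdown([["name", "score"], ["a", "1"], ["b", "2"]]))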
@ -199,9 +230,9 @@ def _download_file_content(file: File) -> bytes:
raise FileDownloadError("Missing URL for remote file")
response = ssrf_proxy.get(file.remote_url)
response.raise_for_status()
return response.content
return cast(bytes, response.content)
else:
return file_manager.download(file)
return cast(bytes, file_manager.download(file))
except Exception as e:
raise FileDownloadError(f"Error downloading file: {str(e)}") from e
@ -256,6 +287,8 @@ def _extract_text_from_excel(file_content: bytes) -> str:
def _extract_text_from_ppt(file_content: bytes) -> str:
from unstructured.partition.ppt import partition_ppt
try:
with io.BytesIO(file_content) as file:
elements = partition_ppt(file=file)
@ -265,6 +298,9 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
def _extract_text_from_pptx(file_content: bytes) -> str:
from unstructured.partition.api import partition_via_api
from unstructured.partition.pptx import partition_pptx
try:
if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
@ -287,6 +323,8 @@ def _extract_text_from_pptx(file_content: bytes) -> str:
def _extract_text_from_epub(file_content: bytes) -> str:
from unstructured.partition.epub import partition_epub
try:
with io.BytesIO(file_content) as file:
elements = partition_epub(file=file)
@ -296,6 +334,8 @@ def _extract_text_from_epub(file_content: bytes) -> str:
def _extract_text_from_eml(file_content: bytes) -> str:
from unstructured.partition.email import partition_email
try:
with io.BytesIO(file_content) as file:
elements = partition_email(file=file)
@ -305,6 +345,8 @@ def _extract_text_from_eml(file_content: bytes) -> str:
def _extract_text_from_msg(file_content: bytes) -> str:
from unstructured.partition.msg import partition_msg
try:
with io.BytesIO(file_content) as file:
elements = partition_msg(file=file)

View File

@ -67,7 +67,7 @@ class EndStreamGeneratorRouter:
and node_type == NodeType.LLM.value
and variable_selector.value_selector[1] == "text"
):
value_selectors.append(variable_selector.value_selector)
value_selectors.append(list(variable_selector.value_selector))
return value_selectors
@ -119,8 +119,7 @@ class EndStreamGeneratorRouter:
current_node_id: str,
end_node_id: str,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]],
# type: ignore[name-defined]
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
end_dependencies: dict[str, list[str]],
) -> None:
"""
@ -135,6 +134,8 @@ class EndStreamGeneratorRouter:
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
for edge in reverse_edges:
source_node_id = edge.source_node_id
if source_node_id not in node_id_config_mapping:
continue
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
if source_node_type in {
NodeType.IF_ELSE.value,

View File

@ -23,7 +23,7 @@ class EndStreamProcessor(StreamProcessor):
self.route_position[end_node_id] = 0
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
self.has_output = False
self.output_node_ids = set()
self.output_node_ids: set[str] = set()
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
for event in generator:

View File

@ -36,3 +36,4 @@ class FailBranchSourceHandle(StrEnum):
CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
RETRY_ON_ERROR_NODE_TYPE = CONTINUE_ON_ERROR_NODE_TYPE
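RETRY_ON_ERROR_NODE_TYPE aliases the continue-on-error list, so the same four node types are retryable. A sketch of the gate that BaseNode.should_retry applies (the NodeType values here are illustrative):

from enum import StrEnum

class NodeType(StrEnum):
    LLM = "llm"
    CODE = "code"
    TOOL = "tool"
    HTTP_REQUEST = "http-request"
    ANSWER = "answer"

CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
RETRY_ON_ERROR_NODE_TYPE = CONTINUE_ON_ERROR_NODE_TYPE

def should_retry(node_type: NodeType, retry_enabled: bool) -> bool:
    # mirrors BaseNode.should_retry: the config flag plus the type allowlist
    return retry_enabled and node_type in RETRY_ON_ERROR_NODE_TYPE

print(should_retry(NodeType.HTTP_REQUEST, True))  # True
print(should_retry(NodeType.ANSWER, True))        # False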

View File

@ -1,4 +1,10 @@
from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from .event import (
ModelInvokeCompletedEvent,
RunCompletedEvent,
RunRetrieverResourceEvent,
RunRetryEvent,
RunStreamChunkEvent,
)
from .types import NodeEvent
__all__ = [
@ -6,5 +12,6 @@ __all__ = [
"NodeEvent",
"RunCompletedEvent",
"RunRetrieverResourceEvent",
"RunRetryEvent",
"RunStreamChunkEvent",
]

View File

@ -1,7 +1,10 @@
from datetime import datetime
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.node_entities import NodeRunResult
from models.workflow import WorkflowNodeExecutionStatus
class RunCompletedEvent(BaseModel):
@ -26,3 +29,19 @@ class ModelInvokeCompletedEvent(BaseModel):
text: str
usage: LLMUsage
finish_reason: str | None = None
class RunRetryEvent(BaseModel):
"""Node Run Retry event"""
error: str = Field(..., description="error")
retry_index: int = Field(..., description="Retry attempt number")
start_at: datetime = Field(..., description="Retry start time")
class SingleStepRetryEvent(NodeRunResult):
"""Single step retry event"""
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RETRY
elapsed_time: float = Field(..., description="elapsed time")

View File

@ -16,3 +16,7 @@ class InvalidHttpMethodError(HttpRequestNodeError):
class ResponseSizeError(HttpRequestNodeError):
"""Raised when the response size exceeds the allowed threshold."""
class RequestBodyError(HttpRequestNodeError):
"""Raised when the request body is invalid."""

View File

@ -23,6 +23,7 @@ from .exc import (
FileFetchError,
HttpRequestNodeError,
InvalidHttpMethodError,
RequestBodyError,
ResponseSizeError,
)
@ -45,6 +46,7 @@ class Executor:
headers: dict[str, str]
auth: HttpRequestNodeAuthorization
timeout: HttpRequestNodeTimeout
max_retries: int
boundary: str
@ -54,6 +56,7 @@ class Executor:
node_data: HttpRequestNodeData,
timeout: HttpRequestNodeTimeout,
variable_pool: VariablePool,
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
):
# If authorization API key is present, convert the API key using the variable pool
if node_data.authorization.type == "api-key":
@ -73,6 +76,7 @@ class Executor:
self.files = None
self.data = None
self.json = None
self.max_retries = max_retries
# init template
self.variable_pool = variable_pool
@ -103,9 +107,9 @@ class Executor:
if not (key := key.strip()):
continue
value = value[0].strip() if value else ""
value_str = value[0].strip() if value else ""
result.append(
(self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value).text)
(self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text)
)
self.params = result
@ -140,13 +144,19 @@ class Executor:
case "none":
self.content = ""
case "raw-text":
if len(data) != 1:
raise RequestBodyError("raw-text body type should have exactly one item")
self.content = self.variable_pool.convert_template(data[0].value).text
case "json":
if len(data) != 1:
raise RequestBodyError("json body type should have exactly one item")
json_string = self.variable_pool.convert_template(data[0].value).text
json_object = json.loads(json_string, strict=False)
self.json = json_object
# self.json = self._parse_object_contains_variables(json_object)
case "binary":
if len(data) != 1:
raise RequestBodyError("binary body type should have exactly one item")
file_selector = data[0].file
file_variable = self.variable_pool.get_file(file_selector)
if file_variable is None:
@ -172,9 +182,10 @@ class Executor:
self.variable_pool.convert_template(item.key).text: item.file
for item in filter(lambda item: item.type == "file", data)
}
files: dict[str, Any] = {}
files = {k: self.variable_pool.get_file(selector) for k, selector in file_selectors.items()}
files = {k: v for k, v in files.items() if v is not None}
files = {k: variable.value for k, variable in files.items()}
files = {k: variable.value for k, variable in files.items() if variable is not None}
files = {
k: (v.filename, file_manager.download(v), v.mime_type or "application/octet-stream")
for k, v in files.items()
@ -241,13 +252,15 @@ class Executor:
"params": self.params,
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
"follow_redirects": True,
"max_retries": self.max_retries,
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
response = getattr(ssrf_proxy, self.method)(**request_args)
except ssrf_proxy.MaxRetriesExceededError as e:
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
raise HttpRequestNodeError(str(e))
return response
# FIXME: remove this type ignore; the mismatch may be an httpx typing issue
return response # type: ignore
def invoke(self) -> Response:
# assemble headers
@ -289,35 +302,37 @@ class Executor:
continue
raw += f"{k}: {v}\r\n"
body = ""
body_string = ""
if self.files:
for k, v in self.files.items():
body += f"--{boundary}\r\n"
body += f'Content-Disposition: form-data; name="{k}"\r\n\r\n'
body += f"{v[1]}\r\n"
body += f"--{boundary}--\r\n"
body_string += f"--{boundary}\r\n"
body_string += f'Content-Disposition: form-data; name="{k}"\r\n\r\n'
body_string += f"{v[1]}\r\n"
body_string += f"--{boundary}--\r\n"
elif self.node_data.body:
if self.content:
if isinstance(self.content, str):
body = self.content
body_string = self.content
elif isinstance(self.content, bytes):
body = self.content.decode("utf-8", errors="replace")
body_string = self.content.decode("utf-8", errors="replace")
elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
body = urlencode(self.data)
body_string = urlencode(self.data)
elif self.data and self.node_data.body.type == "form-data":
for key, value in self.data.items():
body += f"--{boundary}\r\n"
body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
body += f"{value}\r\n"
body += f"--{boundary}--\r\n"
body_string += f"--{boundary}\r\n"
body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
body_string += f"{value}\r\n"
body_string += f"--{boundary}--\r\n"
elif self.json:
body = json.dumps(self.json)
body_string = json.dumps(self.json)
elif self.node_data.body.type == "raw-text":
body = self.node_data.body.data[0].value
if body:
raw += f"Content-Length: {len(body)}\r\n"
if len(self.node_data.body.data) != 1:
raise RequestBodyError("raw-text body type should have exactly one item")
body_string = self.node_data.body.data[0].value
if body_string:
raw += f"Content-Length: {len(body_string)}\r\n"
raw += "\r\n" # Empty line between headers and body
raw += body
raw += body_string
return raw
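The body-to-body_string rename avoids reusing one name for several shapes of data, and the form-data branches assemble a multipart payload by hand. A minimal sketch of the conventional multipart/form-data layout with a fixed boundary (a hypothetical helper, not the Executor method; note the single closing terminator after all parts):

def multipart_form_data(fields: dict[str, str], boundary: str) -> str:
    # each part: opening boundary, Content-Disposition header, blank line, value
    body = ""
    for name, value in fields.items():
        body += f"--{boundary}\r\n"
        body += f'Content-Disposition: form-data; name="{name}"\r\n\r\n'
        body += f"{value}\r\n"
    # one closing terminator ends the payload
    body += f"--{boundary}--\r\n"
    return body

print(multipart_form_data({"key": "value"}, boundary="----dify-boundary"))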

View File

@ -1,6 +1,7 @@
import logging
import mimetypes
from collections.abc import Mapping, Sequence
from typing import Any
from typing import Any, Optional
from configs import dify_config
from core.file import File, FileTransferMethod
@ -19,7 +20,7 @@ from .entities import (
HttpRequestNodeTimeout,
Response,
)
from .exc import HttpRequestNodeError
from .exc import HttpRequestNodeError, RequestBodyError
HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
@ -35,7 +36,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
_node_type = NodeType.HTTP_REQUEST
@classmethod
def get_default_config(cls, filters: dict | None = None) -> dict:
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
return {
"type": "http-request",
"config": {
@ -51,6 +52,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
"max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
},
},
"retry_config": {
"max_retries": dify_config.SSRF_DEFAULT_MAX_RETRIES,
"retry_interval": 0.5 * (2**2),
"retry_enabled": True,
},
}
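For reference, the retry_interval default above is plain arithmetic: 0.5 * (2**2) evaluates to 2.0. A quick check (the max-retries value is an assumption standing in for dify_config.SSRF_DEFAULT_MAX_RETRIES):

DEFAULT_MAX_RETRIES = 3  # assumed stand-in for dify_config.SSRF_DEFAULT_MAX_RETRIES

retry_config = {
    "max_retries": DEFAULT_MAX_RETRIES,
    "retry_interval": 0.5 * (2**2),
    "retry_enabled": True,
}
assert retry_config["retry_interval"] == 2.0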
def _run(self) -> NodeRunResult:
@ -60,12 +66,13 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
node_data=self.node_data,
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0,
)
process_data["request"] = http_executor.to_log()
response = http_executor.invoke()
files = self.extract_files(url=http_executor.url, response=response)
if not response.response.is_success and self.should_continue_on_error:
if not response.response.is_success and (self.should_continue_on_error or self.should_retry):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
outputs={
@ -129,9 +136,13 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
data = node_data.body.data
match body_type:
case "binary":
if len(data) != 1:
raise RequestBodyError("invalid body data, should have only one item")
selector = data[0].file
selectors.append(VariableSelector(variable="#" + ".".join(selector) + "#", value_selector=selector))
case "json" | "raw-text":
if len(data) != 1:
raise RequestBodyError("invalid body data, should have only one item")
selectors += variable_template_parser.extract_selectors_from_template(data[0].key)
selectors += variable_template_parser.extract_selectors_from_template(data[0].value)
case "x-www-form-urlencoded":
@ -149,27 +160,31 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
)
mapping = {}
for selector in selectors:
mapping[node_id + "." + selector.variable] = selector.value_selector
for selector_iter in selectors:
mapping[node_id + "." + selector_iter.variable] = selector_iter.value_selector
return mapping
def extract_files(self, url: str, response: Response) -> list[File]:
"""
Extract files from response
Extract files from response by checking both Content-Type header and URL
"""
files = []
is_file = response.is_file
content_type = response.content_type
content = response.content
if is_file and content_type:
if is_file:
# Derive the MIME type from the Content-Type header, falling back to a guess from the URL's filename
filename = url.split("?")[0].split("/")[-1] or ""
mime_type = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
tool_file = ToolFileManager.create_file_by_raw(
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,
file_binary=content,
mimetype=content_type,
mimetype=mime_type,
)
mapping = {

View File

@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
from flask import Flask, current_app
from configs import dify_config
from core.variables import IntegerVariable
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
from core.workflow.entities.node_entities import (
NodeRunMetadataKey,
NodeRunResult,
@ -76,12 +76,15 @@ class IterationNode(BaseNode[IterationNodeData]):
"""
Run the node.
"""
iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
if not iterator_list_segment:
raise IteratorVariableNotFoundError(f"Iterator variable {self.node_data.iterator_selector} not found")
if not variable:
raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")
if len(iterator_list_segment.value) == 0:
if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
if isinstance(variable, NoneVariable) or len(variable.value) == 0:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -90,7 +93,7 @@ class IterationNode(BaseNode[IterationNodeData]):
)
return
iterator_list_value = iterator_list_segment.to_object()
iterator_list_value = variable.to_object()
if not isinstance(iterator_list_value, list):
raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
@ -360,13 +363,16 @@ class IterationNode(BaseNode[IterationNodeData]):
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
if self.node_data.is_parallel:
metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
else:
metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index
metadata = {
**metadata,
NodeRunMetadataKey.ITERATION_ID: self.node_id,
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID
if self.node_data.is_parallel
else NodeRunMetadataKey.ITERATION_INDEX: parallel_mode_run_id
if self.node_data.is_parallel
else iter_run_index,
}
event.route_node_state.node_run_result.metadata = metadata
return event
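The rewrite above folds the old if/else into one dict literal whose key and value both switch on is_parallel. A standalone sketch of that pattern, with string keys standing in for NodeRunMetadataKey:

def merged_metadata(metadata: dict, *, node_id: str, is_parallel: bool,
                    parallel_mode_run_id: str | None, iter_run_index: int) -> dict:
    # one literal: both the key and the value depend on is_parallel
    return {
        **metadata,
        "iteration_id": node_id,
        ("parallel_mode_run_id" if is_parallel else "iteration_index"):
            parallel_mode_run_id if is_parallel else iter_run_index,
    }

print(merged_metadata({}, node_id="iter-1", is_parallel=False,
                      parallel_mode_run_id=None, iter_run_index=2))
# {'iteration_id': 'iter-1', 'iteration_index': 2}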

View File

@ -70,7 +70,20 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
except KnowledgeRetrievalNodeError as e:
logger.warning("Error when running knowledge retrieval node")
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
)
# Temporary handle all exceptions from DatasetRetrieval class here.
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
)
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
available_datasets = []
@ -134,6 +147,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
planning_strategy=planning_strategy,
)
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required")
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
if node_data.multiple_retrieval_config.reranking_model:
reranking_model = {
@ -144,6 +159,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
reranking_model = None
weights = None
elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":
if node_data.multiple_retrieval_config.weights is None:
raise ValueError("weights is required")
reranking_model = None
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
weights = {
@ -160,18 +177,20 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
reranking_model = None
weights = None
all_documents = dataset_retrieval.multiple_retrieve(
self.app_id,
self.tenant_id,
self.user_id,
self.user_from.value,
available_datasets,
query,
node_data.multiple_retrieval_config.top_k,
node_data.multiple_retrieval_config.score_threshold,
node_data.multiple_retrieval_config.reranking_mode,
reranking_model,
weights,
node_data.multiple_retrieval_config.reranking_enable,
app_id=self.app_id,
tenant_id=self.tenant_id,
user_id=self.user_id,
user_from=self.user_from.value,
available_datasets=available_datasets,
query=query,
top_k=node_data.multiple_retrieval_config.top_k,
score_threshold=node_data.multiple_retrieval_config.score_threshold
if node_data.multiple_retrieval_config.score_threshold is not None
else 0.0,
reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
reranking_model=reranking_model,
weights=weights,
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
)
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
@ -192,7 +211,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
"content": item.page_content,
}
retrieval_resource_list.append(source)
document_score_list = {}
document_score_list: dict[str, float] = {}
# deal with dify documents
if dify_documents:
document_score_list = {}
@ -247,7 +266,9 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
retrieval_resource_list.append(source)
if retrieval_resource_list:
retrieval_resource_list = sorted(
retrieval_resource_list, key=lambda x: x.get("metadata").get("score") or 0.0, reverse=True
retrieval_resource_list,
key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0,
reverse=True,
)
position = 1
for item in retrieval_resource_list:
@ -282,6 +303,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
:param node_data: node data
:return:
"""
if node_data.single_retrieval_config is None:
raise ValueError("single_retrieval_config is required")
model_name = node_data.single_retrieval_config.model.name
provider_name = node_data.single_retrieval_config.model.provider

View File

@ -1,5 +1,5 @@
from collections.abc import Callable, Sequence
from typing import Literal, Union
from typing import Any, Literal, Union
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
@ -17,9 +17,9 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
_node_type = NodeType.LIST_OPERATOR
def _run(self):
inputs = {}
process_data = {}
outputs = {}
inputs: dict[str, list] = {}
process_data: dict[str, list] = {}
outputs: dict[str, Any] = {}
variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable)
if variable is None:
@ -93,6 +93,8 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
def _apply_filter(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
filter_func: Callable[[Any], bool]
result: list[Any] = []
for condition in self.node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment):
if not isinstance(condition.value, str):
@ -236,6 +238,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
extract_func: Callable[[File], Any]
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
extract_func = _get_file_extract_string_func(key=key)
return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
@ -249,47 +252,47 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
raise InvalidKeyError(f"Invalid key: {key}")
def _contains(value: str):
def _contains(value: str) -> Callable[[str], bool]:
return lambda x: value in x
def _startswith(value: str):
def _startswith(value: str) -> Callable[[str], bool]:
return lambda x: x.startswith(value)
def _endswith(value: str):
def _endswith(value: str) -> Callable[[str], bool]:
return lambda x: x.endswith(value)
def _is(value: str):
def _is(value: str) -> Callable[[str], bool]:
return lambda x: x is value
def _in(value: str | Sequence[str]):
def _in(value: str | Sequence[str]) -> Callable[[str], bool]:
return lambda x: x in value
def _eq(value: int | float):
def _eq(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x == value
def _ne(value: int | float):
def _ne(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x != value
def _lt(value: int | float):
def _lt(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x < value
def _le(value: int | float):
def _le(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x <= value
def _gt(value: int | float):
def _gt(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x > value
def _ge(value: int | float):
def _ge(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x >= value
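Each helper above now advertises that it returns a closure. A short sketch of composing them the way the list-operator filter does:

from collections.abc import Callable

def _gt(value: int | float) -> Callable[[int | float], bool]:
    return lambda x: x > value

def _le(value: int | float) -> Callable[[int | float], bool]:
    return lambda x: x <= value

def filter_numbers(items: list[float],
                   predicate: Callable[[int | float], bool]) -> list[float]:
    return [x for x in items if predicate(x)]

print(filter_numbers([1.0, 2.5, 3.0], _gt(2)))  # [2.5, 3.0]
print(filter_numbers([1.0, 2.5, 3.0], _le(2)))  # [1.0]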
@ -302,6 +305,7 @@ def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]):
def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]):
extract_func: Callable[[File], Any]
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}:
extract_func = _get_file_extract_string_func(key=order_by)
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")

View File

@ -50,6 +50,7 @@ class PromptConfig(BaseModel):
class LLMNodeChatModelMessage(ChatModelMessage):
text: str = ""
jinja2_text: Optional[str] = None

View File

@ -88,8 +88,8 @@ class LLMNode(BaseNode[LLMNodeData]):
_node_data_cls = LLMNodeData
_node_type = NodeType.LLM
def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]:
node_inputs = None
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
node_inputs: Optional[dict[str, Any]] = None
process_data = None
try:
@ -145,8 +145,8 @@ class LLMNode(BaseNode[LLMNodeData]):
query = query_variable.text
prompt_messages, stop = self._fetch_prompt_messages(
user_query=query,
user_files=files,
sys_query=query,
sys_files=files,
context=context,
memory=memory,
model_config=model_config,
@ -196,7 +196,6 @@ class LLMNode(BaseNode[LLMNodeData]):
error_type=type(e).__name__,
)
)
return
except Exception as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
@ -206,7 +205,6 @@ class LLMNode(BaseNode[LLMNodeData]):
process_data=process_data,
)
)
return
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
@ -302,7 +300,7 @@ class LLMNode(BaseNode[LLMNodeData]):
return messages
def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
variables = {}
variables: dict[str, Any] = {}
if not node_data.prompt_config:
return variables
@ -319,7 +317,7 @@ class LLMNode(BaseNode[LLMNodeData]):
"""
# check if it's a context structure
if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict:
return input_dict["content"]
return str(input_dict["content"])
# else, parse the dict
try:
@ -545,8 +543,8 @@ class LLMNode(BaseNode[LLMNodeData]):
def _fetch_prompt_messages(
self,
*,
user_query: str | None = None,
user_files: Sequence["File"],
sys_query: str | None = None,
sys_files: Sequence["File"],
context: str | None = None,
memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity,
@ -557,12 +555,13 @@ class LLMNode(BaseNode[LLMNodeData]):
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
prompt_messages = []
# FIXME: fix the type error; prompt_messages is re-typed quite a few times
prompt_messages: list[Any] = []
if isinstance(prompt_template, list):
# For chat model
prompt_messages.extend(
_handle_list_messages(
self._handle_list_messages(
messages=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
@ -581,14 +580,14 @@ class LLMNode(BaseNode[LLMNodeData]):
prompt_messages.extend(memory_messages)
# Add current query to the prompt messages
if user_query:
if sys_query:
message = LLMNodeChatModelMessage(
text=user_query,
text=sys_query,
role=PromptMessageRole.USER,
edition_type="basic",
)
prompt_messages.extend(
_handle_list_messages(
self._handle_list_messages(
messages=[message],
context="",
jinja2_variables=[],
@ -635,24 +634,27 @@ class LLMNode(BaseNode[LLMNodeData]):
raise ValueError("Invalid prompt content type")
# Add current query to the prompt message
if user_query:
if sys_query:
if prompt_content_type == str:
prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query)
prompt_messages[0].content = prompt_content
elif prompt_content_type == list:
for content_item in prompt_content:
if content_item.type == PromptMessageContentType.TEXT:
content_item.data = user_query + "\n" + content_item.data
content_item.data = sys_query + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
else:
raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
if vision_enabled and user_files:
# sys_files will be deprecated later
if vision_enabled and sys_files:
file_prompts = []
for file in user_files:
for file in sys_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
# If last prompt is a user prompt, add files into its contents,
# otherwise append a new user prompt
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
@ -662,7 +664,7 @@ class LLMNode(BaseNode[LLMNodeData]):
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# Filter prompt messages
# Remove empty messages and filter unsupported content
filtered_prompt_messages = []
for prompt_message in prompt_messages:
if isinstance(prompt_message.content, list):
@ -780,7 +782,7 @@ class LLMNode(BaseNode[LLMNodeData]):
else:
raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
variable_mapping = {}
variable_mapping: dict[str, Any] = {}
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
@ -846,6 +848,68 @@ class LLMNode(BaseNode[LLMNodeData]):
},
}
def _handle_list_messages(
self,
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
) -> Sequence[PromptMessage]:
prompt_messages: list[PromptMessage] = []
for message in messages:
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
jinjia2_variables=jinja2_variables,
variable_pool=variable_pool,
)
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)], role=message.role
)
prompt_messages.append(prompt_message)
else:
# Get segment group from basic message
if context:
template = message.text.replace("{#context#}", context)
else:
template = message.text
segment_group = variable_pool.convert_template(template)
# Process segments for images
file_contents = []
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
elif isinstance(segment, FileSegment):
file = segment.value
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
# Create message with text from all segments
plain_text = segment_group.text
if plain_text:
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=plain_text)], role=message.role
)
prompt_messages.append(prompt_message)
if file_contents:
# Create message with image contents
prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
prompt_messages.append(prompt_message)
return prompt_messages
def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
match role:
@ -880,68 +944,6 @@ def _render_jinja2_message(
return result_text
def _handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
) -> Sequence[PromptMessage]:
prompt_messages = []
for message in messages:
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
jinjia2_variables=jinja2_variables,
variable_pool=variable_pool,
)
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)], role=message.role
)
prompt_messages.append(prompt_message)
else:
# Get segment group from basic message
if context:
template = message.text.replace("{#context#}", context)
else:
template = message.text
segment_group = variable_pool.convert_template(template)
# Process segments for images
file_contents = []
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
if isinstance(segment, FileSegment):
file = segment.value
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
# Create message with text from all segments
plain_text = segment_group.text
if plain_text:
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=plain_text)], role=message.role
)
prompt_messages.append(prompt_message)
if file_contents:
# Create message with image contents
prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
prompt_messages.append(prompt_message)
return prompt_messages
def _calculate_rest_token(
*, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> int:
@ -978,7 +980,7 @@ def _handle_memory_chat_mode(
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
memory_messages = []
memory_messages: Sequence[PromptMessage] = []
# Get messages from memory for chat model
if memory and memory_config:
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
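
Since _fetch_prompt_messages takes keyword-only arguments, the user_query/user_files to sys_query/sys_files rename fails loudly at any call site that was not updated, rather than silently rebinding positionals. A toy sketch of that property, using a hypothetical simplified signature:

def fetch_prompt_messages(*, sys_query: str | None = None, sys_files: tuple = ()) -> list[str]:
    return [sys_query] if sys_query else []

print(fetch_prompt_messages(sys_query="hello"))  # ['hello']
# fetch_prompt_messages(user_query="hello")      # TypeError: unexpected keyword argument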

View File

@ -14,8 +14,8 @@ class LoopNode(BaseNode[LoopNodeData]):
_node_data_cls = LoopNodeData
_node_type = NodeType.LOOP
def _run(self) -> LoopState:
return super()._run()
def _run(self) -> LoopState: # type: ignore
return super()._run() # type: ignore
@classmethod
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
@ -28,7 +28,7 @@ class LoopNode(BaseNode[LoopNodeData]):
# TODO waiting for implementation
return [
Condition(
Condition( # type: ignore
variable_selector=[node_id, "index"],
comparison_operator="",
value_type="value_selector",

View File

@ -25,7 +25,7 @@ class ParameterConfig(BaseModel):
raise ValueError("Parameter name is required")
if value in {"__reason", "__is_success"}:
raise ValueError("Invalid parameter name, __reason and __is_success are reserved")
return value
return str(value)
class ParameterExtractorNodeData(BaseNodeData):
@ -52,7 +52,7 @@ class ParameterExtractorNodeData(BaseNodeData):
:return: parameter json schema
"""
parameters = {"type": "object", "properties": {}, "required": []}
parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
for parameter in self.parameters:
parameter_schema: dict[str, Any] = {"description": parameter.description}
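
The dict[str, Any] annotation matters because a checker otherwise infers the value type from the first literal and then rejects the heterogeneous writes that follow. A minimal sketch of the schema assembly, with made-up parameters:

from typing import Any

parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
for name, description, required in [("city", "City name", True), ("unit", "C or F", False)]:
    parameters["properties"][name] = {"type": "string", "description": description}
    if required:
        parameters["required"].append(name)

print(parameters["required"])  # ['city']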

View File

@ -63,7 +63,8 @@ class ParameterExtractorNode(LLMNode):
Parameter Extractor Node.
"""
_node_data_cls = ParameterExtractorNodeData
# FIXME: figure out why this differs from the superclass
_node_data_cls = ParameterExtractorNodeData # type: ignore
_node_type = NodeType.PARAMETER_EXTRACTOR
_model_instance: Optional[ModelInstance] = None
@ -179,6 +180,15 @@ class ParameterExtractorNode(LLMNode):
error=str(e),
metadata={},
)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=inputs,
process_data=process_data,
outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)},
error=str(e),
metadata={},
)
error = None
@ -244,6 +254,9 @@ class ParameterExtractorNode(LLMNode):
# deduct quota
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
if text is None:
text = ""
return text, usage, tool_call
def _generate_function_call_prompt(
@ -596,9 +609,10 @@ class ParameterExtractorNode(LLMNode):
json_str = extract_json(result[idx:])
if json_str:
try:
return json.loads(json_str)
return cast(dict, json.loads(json_str))
except Exception:
pass
return None
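
The cast(dict, ...) only narrows the static type; the runtime guarantees still come from the try/except and the trailing return None. A standalone sketch of the same extract-then-parse fallback, where first_json_object is a simplified stand-in for extract_json:

import json
from typing import Any, Optional, cast

def first_json_object(text: str) -> Optional[dict[str, Any]]:
    start = text.find("{")  # naive extraction from the model's reply
    if start == -1:
        return None
    try:
        return cast(dict[str, Any], json.loads(text[start:]))
    except json.JSONDecodeError:
        return None

print(first_json_object('reply: {"location": "SF"}'))  # {'location': 'SF'}
print(first_json_object("no structure here"))          # None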
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
"""
@ -607,13 +621,13 @@ class ParameterExtractorNode(LLMNode):
if not tool_call or not tool_call.function.arguments:
return None
return json.loads(tool_call.function.arguments)
return cast(dict, json.loads(tool_call.function.arguments))
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
"""
Generate default result.
"""
result = {}
result: dict[str, Any] = {}
for parameter in data.parameters:
if parameter.type == "number":
result[parameter.name] = 0
@ -763,7 +777,7 @@ class ParameterExtractorNode(LLMNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ParameterExtractorNodeData,
node_data: ParameterExtractorNodeData, # type: ignore
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -772,6 +786,7 @@ class ParameterExtractorNode(LLMNode):
:param node_data: node data
:return:
"""
# FIXME: fix the type error later
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
if node_data.instruction:

View File

@ -1,3 +1,5 @@
from typing import Any
FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters"
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy.
@ -35,7 +37,7 @@ FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information fr
</structure>
""" # noqa: E501
FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [
FUNCTION_CALLING_EXTRACTOR_EXAMPLE: list[dict[str, Any]] = [
{
"user": {
"query": "What is the weather today in SF?",

View File

@ -1,10 +1,8 @@
import json
import logging
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import Any, Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.llm_generator.output_parser.errors import OutputParserError
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
@ -36,12 +34,9 @@ from .template_prompts import (
QUESTION_CLASSIFIER_USER_PROMPT_3,
)
if TYPE_CHECKING:
from core.file import File
class QuestionClassifierNode(LLMNode):
_node_data_cls = QuestionClassifierNodeData
_node_data_cls = QuestionClassifierNodeData # type: ignore
_node_type = NodeType.QUESTION_CLASSIFIER
def _run(self):
@ -66,7 +61,7 @@ class QuestionClassifierNode(LLMNode):
node_data.instruction = variable_pool.convert_template(node_data.instruction).text
files: Sequence[File] = (
files = (
self._fetch_files(
selector=node_data.vision.configs.variable_selector,
)
@ -89,37 +84,38 @@ class QuestionClassifierNode(LLMNode):
)
prompt_messages, stop = self._fetch_prompt_messages(
prompt_template=prompt_template,
user_query=query,
sys_query=query,
memory=memory,
model_config=model_config,
user_files=files,
sys_files=files,
vision_enabled=node_data.vision.enabled,
vision_detail=node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=[],
)
# handle invoke result
generator = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
)
result_text = ""
usage = LLMUsage.empty_usage()
finish_reason = None
for event in generator:
if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
usage = event.usage
finish_reason = event.finish_reason
break
category_name = node_data.classes[0].name
category_id = node_data.classes[0].id
try:
# handle invoke result
generator = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
)
for event in generator:
if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
usage = event.usage
finish_reason = event.finish_reason
break
category_name = node_data.classes[0].name
category_id = node_data.classes[0].id
result_text_json = parse_and_check_json_markdown(result_text, [])
# result_text_json = json.loads(result_text.strip('```JSON\n'))
if "category_name" in result_text_json and "category_id" in result_text_json:
@ -130,10 +126,6 @@ class QuestionClassifierNode(LLMNode):
if category_id_result in category_ids:
category_name = classes_map[category_id_result]
category_id = category_id_result
except OutputParserError:
logging.exception(f"Failed to parse result text: {result_text}")
try:
process_data = {
"model_mode": model_config.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
@ -157,7 +149,6 @@ class QuestionClassifierNode(LLMNode):
},
llm_usage=usage,
)
except ValueError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -177,7 +168,7 @@ class QuestionClassifierNode(LLMNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: QuestionClassifierNodeData,
node_data: Any,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -186,6 +177,7 @@ class QuestionClassifierNode(LLMNode):
:param node_data: node data
:return:
"""
node_data = cast(QuestionClassifierNodeData, node_data)
variable_mapping = {"query": node_data.query_variable_selector}
variable_selectors = []
if node_data.instruction:
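
The restructured try block works because the fallbacks (the first class, empty usage) are assigned before anything fallible runs, so the failure path can still produce a usable result. A reduced sketch of that shape; pick_category and the plain json.loads are simplifications of the node's parse_and_check_json_markdown flow:

import json

def pick_category(result_text: str, classes_map: dict[str, str]) -> str:
    category_name = next(iter(classes_map.values()))  # default: the first class
    try:
        parsed = json.loads(result_text)
        category_id = str(parsed.get("category_id", ""))
        if category_id in classes_map:
            category_name = classes_map[category_id]
    except json.JSONDecodeError:
        pass  # keep the default when the model output is unparseable
    return category_name

classes = {"c1": "billing", "c2": "support"}
print(pick_category('{"category_id": "c2"}', classes))  # support
print(pick_category("not json at all", classes))        # billing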

View File

@ -9,7 +9,6 @@ from core.file import File, FileTransferMethod, FileType
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
@ -58,6 +57,8 @@ class ToolNode(BaseNode[ToolNodeData]):
# get tool runtime
try:
from core.tools.tool_manager import ToolManager
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
)
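
Moving the ToolManager import into the method defers it until the node actually runs, which is the usual fix for a cycle between two modules that import each other at load time. A runnable illustration of the timing, with json standing in for the real dependency:

def run_tool() -> str:
    # Executed on call, not at module import time, so a load-time cycle
    # between this module and the dependency never materializes.
    from json import dumps  # stands in for the ToolManager import
    return dumps({"ok": True})

print(run_tool())  # {"ok": true}
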
@ -145,7 +146,7 @@ class ToolNode(BaseNode[ToolNodeData]):
"""
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
result = {}
result: dict[str, Any] = {}
for parameter_name in node_data.tool_parameters:
parameter = tool_parameters_dictionary.get(parameter_name)
if not parameter:

View File

@ -1,4 +1,4 @@
class VariableOperatorNodeError(Exception):
class VariableOperatorNodeError(ValueError):
"""Base error type, don't use directly."""
pass
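
Rebasing the error on ValueError keeps existing except ValueError handlers working while subclasses remain distinguishable. A quick sketch with a hypothetical subclass:

class VariableOperatorNodeError(ValueError):
    """Base error type, don't use directly."""

class InputNotFoundError(VariableOperatorNodeError):  # hypothetical subclass
    pass

try:
    raise InputNotFoundError("income value not found")
except ValueError as e:
    print(type(e).__name__, e)  # InputNotFoundError income value not found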

View File

@ -36,6 +36,8 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
case WriteMode.CLEAR:
income_value = get_zero_value(original_variable.value_type)
if income_value is None:
raise VariableOperatorNodeError("income value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _:

View File

@ -1,5 +1,5 @@
import json
from typing import Any
from typing import Any, cast
from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
@ -29,7 +29,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
def _run(self) -> NodeRunResult:
inputs = self.node_data.model_dump()
process_data = {}
process_data: dict[str, Any] = {}
# NOTE: This node has no outputs
updated_variables: list[Variable] = []
@ -119,7 +119,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
else:
conversation_id = conversation_id.value
common_helpers.update_conversation_variable(
conversation_id=conversation_id,
conversation_id=cast(str, conversation_id),
variable=variable,
)
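
The cast(str, conversation_id) reflects a narrowing the checker cannot follow across the if/else; an explicit isinstance guard narrows for free and fails fast at runtime. A tiny sketch of the guard alternative (require_str is illustrative):

def require_str(value: object) -> str:
    if not isinstance(value, str):
        raise TypeError(f"expected str, got {type(value).__name__}")
    return value  # narrowed by isinstance; no cast needed

print(require_str("conv-123"))  # conv-123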

View File

@ -129,11 +129,11 @@ class WorkflowEntry:
:return:
"""
# fetch node info from workflow graph
graph = workflow.graph_dict
if not graph:
workflow_graph = workflow.graph_dict
if not workflow_graph:
raise ValueError("workflow graph not found")
nodes = graph.get("nodes")
nodes = workflow_graph.get("nodes")
if not nodes:
raise ValueError("nodes not found in workflow graph")
@ -297,7 +297,8 @@ class WorkflowEntry:
@staticmethod
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
return WorkflowEntry._handle_special_values(value)
result = WorkflowEntry._handle_special_values(value)
return result if isinstance(result, Mapping) or result is None else dict(result)
@staticmethod
def _handle_special_values(value: Any) -> Any:
@ -309,10 +310,10 @@ class WorkflowEntry:
res[k] = WorkflowEntry._handle_special_values(v)
return res
if isinstance(value, list):
res = []
res_list = []
for item in value:
res.append(WorkflowEntry._handle_special_values(item))
return res
res_list.append(WorkflowEntry._handle_special_values(item))
return res_list
if isinstance(value, File):
return value.to_dict()
return value
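
Giving the list branch its own res_list keeps each accumulator mono-typed (dict in one branch, list in the other), which is exactly what the checker objected to. A self-contained sketch of the same recursive normalization, with Blob standing in for File:

from typing import Any

class Blob:
    def to_dict(self) -> dict[str, Any]:
        return {"kind": "blob"}

def normalize(value: Any) -> Any:
    if isinstance(value, dict):
        return {k: normalize(v) for k, v in value.items()}
    if isinstance(value, list):
        return [normalize(item) for item in value]
    if isinstance(value, Blob):
        return value.to_dict()
    return value

print(normalize({"files": [Blob()], "n": 1}))  # {'files': [{'kind': 'blob'}], 'n': 1}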