mirror of
https://github.com/langgenius/dify.git
synced 2026-05-02 08:28:03 +08:00
remove bare list, dict, Sequence, None, Any (#25058)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@ -5,7 +5,7 @@ from core.workflow.graph_engine.entities.event import GraphEngineEvent
|
||||
|
||||
class WorkflowCallback(ABC):
|
||||
@abstractmethod
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
def on_event(self, event: GraphEngineEvent):
|
||||
"""
|
||||
Published event
|
||||
"""
|
||||
|
||||
@ -36,10 +36,10 @@ _TEXT_COLOR_MAPPING = {
|
||||
|
||||
|
||||
class WorkflowLoggingCallback(WorkflowCallback):
|
||||
def __init__(self) -> None:
|
||||
def __init__(self):
|
||||
self.current_node_id: Optional[str] = None
|
||||
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
def on_event(self, event: GraphEngineEvent):
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self.print_text("\n[GraphRunStartedEvent]", color="pink")
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
@ -75,7 +75,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
else:
|
||||
self.print_text(f"\n[{event.__class__.__name__}]", color="blue")
|
||||
|
||||
def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None:
|
||||
def on_workflow_node_execute_started(self, event: NodeRunStartedEvent):
|
||||
"""
|
||||
Workflow node execute started
|
||||
"""
|
||||
@ -84,7 +84,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
self.print_text(f"Node Title: {event.node_data.title}", color="yellow")
|
||||
self.print_text(f"Type: {event.node_type.value}", color="yellow")
|
||||
|
||||
def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None:
|
||||
def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent):
|
||||
"""
|
||||
Workflow node execute succeeded
|
||||
"""
|
||||
@ -115,7 +115,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
color="green",
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None:
|
||||
def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent):
|
||||
"""
|
||||
Workflow node execute failed
|
||||
"""
|
||||
@ -143,7 +143,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
color="red",
|
||||
)
|
||||
|
||||
def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None:
|
||||
def on_node_text_chunk(self, event: NodeRunStreamChunkEvent):
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
@ -161,7 +161,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
|
||||
self.print_text(event.chunk_content, color="pink", end="")
|
||||
|
||||
def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent) -> None:
|
||||
def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent):
|
||||
"""
|
||||
Publish parallel started
|
||||
"""
|
||||
@ -173,9 +173,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
if event.in_loop_id:
|
||||
self.print_text(f"Loop ID: {event.in_loop_id}", color="blue")
|
||||
|
||||
def on_workflow_parallel_completed(
|
||||
self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
|
||||
) -> None:
|
||||
def on_workflow_parallel_completed(self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
|
||||
"""
|
||||
Publish parallel completed
|
||||
"""
|
||||
@ -200,14 +198,14 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
if isinstance(event, ParallelBranchRunFailedEvent):
|
||||
self.print_text(f"Error: {event.error}", color=color)
|
||||
|
||||
def on_workflow_iteration_started(self, event: IterationRunStartedEvent) -> None:
|
||||
def on_workflow_iteration_started(self, event: IterationRunStartedEvent):
|
||||
"""
|
||||
Publish iteration started
|
||||
"""
|
||||
self.print_text("\n[IterationRunStartedEvent]", color="blue")
|
||||
self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
|
||||
|
||||
def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None:
|
||||
def on_workflow_iteration_next(self, event: IterationRunNextEvent):
|
||||
"""
|
||||
Publish iteration next
|
||||
"""
|
||||
@ -215,7 +213,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
|
||||
self.print_text(f"Iteration Index: {event.index}", color="blue")
|
||||
|
||||
def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None:
|
||||
def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent):
|
||||
"""
|
||||
Publish iteration completed
|
||||
"""
|
||||
@ -227,14 +225,14 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
)
|
||||
self.print_text(f"Node ID: {event.iteration_id}", color="blue")
|
||||
|
||||
def on_workflow_loop_started(self, event: LoopRunStartedEvent) -> None:
|
||||
def on_workflow_loop_started(self, event: LoopRunStartedEvent):
|
||||
"""
|
||||
Publish loop started
|
||||
"""
|
||||
self.print_text("\n[LoopRunStartedEvent]", color="blue")
|
||||
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
|
||||
|
||||
def on_workflow_loop_next(self, event: LoopRunNextEvent) -> None:
|
||||
def on_workflow_loop_next(self, event: LoopRunNextEvent):
|
||||
"""
|
||||
Publish loop next
|
||||
"""
|
||||
@ -242,7 +240,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
|
||||
self.print_text(f"Loop Index: {event.index}", color="blue")
|
||||
|
||||
def on_workflow_loop_completed(self, event: LoopRunSucceededEvent | LoopRunFailedEvent) -> None:
|
||||
def on_workflow_loop_completed(self, event: LoopRunSucceededEvent | LoopRunFailedEvent):
|
||||
"""
|
||||
Publish loop completed
|
||||
"""
|
||||
@ -252,7 +250,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
)
|
||||
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
|
||||
|
||||
def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None:
|
||||
def print_text(self, text: str, color: Optional[str] = None, end: str = "\n"):
|
||||
"""Print text with highlighting and no end characters."""
|
||||
text_to_print = self._get_colored_text(text, color) if color else text
|
||||
print(f"{text_to_print}", end=end)
|
||||
|
||||
@ -20,7 +20,7 @@ class ConversationVariableUpdater(Protocol):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, conversation_id: str, variable: "Variable") -> None:
|
||||
def update(self, conversation_id: str, variable: "Variable"):
|
||||
"""
|
||||
Updates the value of the specified conversation variable in the underlying storage.
|
||||
|
||||
|
||||
@ -47,7 +47,7 @@ class VariablePool(BaseModel):
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
def model_post_init(self, context: Any, /) -> None:
|
||||
def model_post_init(self, context: Any, /):
|
||||
# Create a mapping from field names to SystemVariableKey enum values
|
||||
self._add_system_variables(self.system_variables)
|
||||
# Add environment variables to the variable pool
|
||||
@ -57,7 +57,7 @@ class VariablePool(BaseModel):
|
||||
for var in self.conversation_variables:
|
||||
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
|
||||
|
||||
def add(self, selector: Sequence[str], value: Any, /) -> None:
|
||||
def add(self, selector: Sequence[str], value: Any, /):
|
||||
"""
|
||||
Add a variable to the variable pool.
|
||||
|
||||
@ -161,11 +161,11 @@ class VariablePool(BaseModel):
|
||||
# Return result as Segment
|
||||
return result if isinstance(result, Segment) else variable_factory.build_segment(result)
|
||||
|
||||
def _extract_value(self, obj: Any) -> Any:
|
||||
def _extract_value(self, obj: Any):
|
||||
"""Extract the actual value from an ObjectSegment."""
|
||||
return obj.value if isinstance(obj, ObjectSegment) else obj
|
||||
|
||||
def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Any:
|
||||
def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str):
|
||||
"""Get a nested attribute from a dictionary-like object."""
|
||||
if not isinstance(obj, dict):
|
||||
return None
|
||||
|
||||
@ -112,7 +112,7 @@ class WorkflowNodeExecution(BaseModel):
|
||||
process_data: Optional[Mapping[str, Any]] = None,
|
||||
outputs: Optional[Mapping[str, Any]] = None,
|
||||
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None,
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Update the model from mappings.
|
||||
|
||||
|
||||
@ -205,9 +205,7 @@ class Graph(BaseModel):
|
||||
return graph
|
||||
|
||||
@classmethod
|
||||
def _recursively_add_node_ids(
|
||||
cls, node_ids: list[str], edge_mapping: dict[str, list[GraphEdge]], node_id: str
|
||||
) -> None:
|
||||
def _recursively_add_node_ids(cls, node_ids: list[str], edge_mapping: dict[str, list[GraphEdge]], node_id: str):
|
||||
"""
|
||||
Recursively add node ids
|
||||
|
||||
@ -225,7 +223,7 @@ class Graph(BaseModel):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _check_connected_to_previous_node(cls, route: list[str], edge_mapping: dict[str, list[GraphEdge]]) -> None:
|
||||
def _check_connected_to_previous_node(cls, route: list[str], edge_mapping: dict[str, list[GraphEdge]]):
|
||||
"""
|
||||
Check whether it is connected to the previous node
|
||||
"""
|
||||
@ -256,7 +254,7 @@ class Graph(BaseModel):
|
||||
parallel_mapping: dict[str, GraphParallel],
|
||||
node_parallel_mapping: dict[str, str],
|
||||
parent_parallel: Optional[GraphParallel] = None,
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Recursively add parallel ids
|
||||
|
||||
@ -461,7 +459,7 @@ class Graph(BaseModel):
|
||||
level_limit: int,
|
||||
parent_parallel_id: str,
|
||||
current_level: int = 1,
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Check if it exceeds N layers of parallel
|
||||
"""
|
||||
@ -488,7 +486,7 @@ class Graph(BaseModel):
|
||||
edge_mapping: dict[str, list[GraphEdge]],
|
||||
merge_node_id: str,
|
||||
start_node_id: str,
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Recursively add node ids
|
||||
|
||||
@ -614,7 +612,7 @@ class Graph(BaseModel):
|
||||
@classmethod
|
||||
def _recursively_fetch_routes(
|
||||
cls, edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: list[str]
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Recursively fetch route
|
||||
"""
|
||||
|
||||
@ -47,7 +47,7 @@ class RouteNodeState(BaseModel):
|
||||
|
||||
index: int = 1
|
||||
|
||||
def set_finished(self, run_result: NodeRunResult) -> None:
|
||||
def set_finished(self, run_result: NodeRunResult):
|
||||
"""
|
||||
Node finished
|
||||
|
||||
@ -94,7 +94,7 @@ class RuntimeRouteState(BaseModel):
|
||||
self.node_state_mapping[state.id] = state
|
||||
return state
|
||||
|
||||
def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None:
|
||||
def add_route(self, source_node_state_id: str, target_node_state_id: str):
|
||||
"""
|
||||
Add route to the graph state
|
||||
|
||||
|
||||
@ -66,7 +66,7 @@ class GraphEngineThreadPool(ThreadPoolExecutor):
|
||||
initializer=None,
|
||||
initargs=(),
|
||||
max_submit_count=dify_config.MAX_SUBMIT_COUNT,
|
||||
) -> None:
|
||||
):
|
||||
super().__init__(max_workers, thread_name_prefix, initializer, initargs)
|
||||
self.max_submit_count = max_submit_count
|
||||
self.submit_count = 0
|
||||
@ -80,7 +80,7 @@ class GraphEngineThreadPool(ThreadPoolExecutor):
|
||||
def task_done_callback(self, future):
|
||||
self.submit_count -= 1
|
||||
|
||||
def check_is_full(self) -> None:
|
||||
def check_is_full(self):
|
||||
if self.submit_count > self.max_submit_count:
|
||||
raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.")
|
||||
|
||||
@ -104,7 +104,7 @@ class GraphEngine:
|
||||
max_execution_steps: int,
|
||||
max_execution_time: int,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
) -> None:
|
||||
):
|
||||
thread_pool_max_submit_count = dify_config.MAX_SUBMIT_COUNT
|
||||
thread_pool_max_workers = 10
|
||||
|
||||
@ -537,7 +537,7 @@ class GraphEngine:
|
||||
parent_parallel_id: Optional[str] = None,
|
||||
parent_parallel_start_node_id: Optional[str] = None,
|
||||
handle_exceptions: list[str] = [],
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Run parallel nodes
|
||||
"""
|
||||
|
||||
@ -66,7 +66,7 @@ class AgentNode(BaseNode):
|
||||
_node_type = NodeType.AGENT
|
||||
_node_data: AgentNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = AgentNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
|
||||
@ -22,7 +22,7 @@ class AnswerNode(BaseNode):
|
||||
|
||||
_node_data: AnswerNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = AnswerNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
|
||||
@ -134,7 +134,7 @@ class AnswerStreamGeneratorRouter:
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
answer_dependencies: dict[str, list[str]],
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Recursive fetch answer dependencies
|
||||
:param current_node_id: current node id
|
||||
|
||||
@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnswerStreamProcessor(StreamProcessor):
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool):
|
||||
super().__init__(graph, variable_pool)
|
||||
self.generate_routes = graph.answer_stream_generate_routes
|
||||
self.route_position = {}
|
||||
@ -66,7 +66,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
else:
|
||||
yield event
|
||||
|
||||
def reset(self) -> None:
|
||||
def reset(self):
|
||||
self.route_position = {}
|
||||
for answer_node_id, _ in self.generate_routes.answer_generate_route.items():
|
||||
self.route_position[answer_node_id] = 0
|
||||
|
||||
@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamProcessor(ABC):
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool):
|
||||
self.graph = graph
|
||||
self.variable_pool = variable_pool
|
||||
self.rest_node_ids = graph.node_ids.copy()
|
||||
@ -20,7 +20,7 @@ class StreamProcessor(ABC):
|
||||
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None:
|
||||
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent):
|
||||
finished_node_id = event.route_node_state.node_id
|
||||
if finished_node_id not in self.rest_node_ids:
|
||||
return
|
||||
@ -89,7 +89,7 @@ class StreamProcessor(ABC):
|
||||
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify))
|
||||
return node_ids
|
||||
|
||||
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
|
||||
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]):
|
||||
"""
|
||||
remove target node ids until merge
|
||||
"""
|
||||
|
||||
@ -28,7 +28,7 @@ class DefaultValue(BaseModel):
|
||||
key: str
|
||||
|
||||
@staticmethod
|
||||
def _parse_json(value: str) -> Any:
|
||||
def _parse_json(value: str):
|
||||
"""Unified JSON parsing handler"""
|
||||
try:
|
||||
return json.loads(value)
|
||||
|
||||
@ -28,7 +28,7 @@ class BaseNode:
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
previous_node_id: Optional[str] = None,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
) -> None:
|
||||
):
|
||||
self.id = id
|
||||
self.tenant_id = graph_init_params.tenant_id
|
||||
self.app_id = graph_init_params.app_id
|
||||
@ -51,7 +51,7 @@ class BaseNode:
|
||||
self.node_id = node_id
|
||||
|
||||
@abstractmethod
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
|
||||
def init_node_data(self, data: Mapping[str, Any]): ...
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
|
||||
@ -141,7 +141,7 @@ class BaseNode:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
def get_default_config(cls, filters: Optional[dict] = None):
|
||||
return {}
|
||||
|
||||
@property
|
||||
|
||||
@ -28,7 +28,7 @@ class CodeNode(BaseNode):
|
||||
|
||||
_node_data: CodeNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = CodeNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
@ -50,7 +50,7 @@ class CodeNode(BaseNode):
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
def get_default_config(cls, filters: Optional[dict] = None):
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters.
|
||||
|
||||
@ -47,7 +47,7 @@ class DocumentExtractorNode(BaseNode):
|
||||
|
||||
_node_data: DocumentExtractorNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = DocumentExtractorNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
|
||||
@ -14,7 +14,7 @@ class EndNode(BaseNode):
|
||||
|
||||
_node_data: EndNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = EndNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
|
||||
@ -121,7 +121,7 @@ class EndStreamGeneratorRouter:
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
end_dependencies: dict[str, list[str]],
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Recursive fetch end dependencies
|
||||
:param current_node_id: current node id
|
||||
|
||||
@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EndStreamProcessor(StreamProcessor):
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool):
|
||||
super().__init__(graph, variable_pool)
|
||||
self.end_stream_param = graph.end_stream_param
|
||||
self.route_position = {}
|
||||
@ -76,7 +76,7 @@ class EndStreamProcessor(StreamProcessor):
|
||||
else:
|
||||
yield event
|
||||
|
||||
def reset(self) -> None:
|
||||
def reset(self):
|
||||
self.route_position = {}
|
||||
for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
|
||||
self.route_position[end_node_id] = 0
|
||||
|
||||
@ -38,7 +38,7 @@ class HttpRequestNode(BaseNode):
|
||||
|
||||
_node_data: HttpRequestNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = HttpRequestNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
@ -60,7 +60,7 @@ class HttpRequestNode(BaseNode):
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
|
||||
def get_default_config(cls, filters: Optional[dict[str, Any]] = None):
|
||||
return {
|
||||
"type": "http-request",
|
||||
"config": {
|
||||
|
||||
@ -19,7 +19,7 @@ class IfElseNode(BaseNode):
|
||||
|
||||
_node_data: IfElseNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = IfElseNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
|
||||
@ -67,7 +67,7 @@ class IterationNode(BaseNode):
|
||||
|
||||
_node_data: IterationNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = IterationNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
@ -89,7 +89,7 @@ class IterationNode(BaseNode):
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
def get_default_config(cls, filters: Optional[dict] = None):
|
||||
return {
|
||||
"type": "iteration",
|
||||
"config": {
|
||||
|
||||
@ -18,7 +18,7 @@ class IterationStartNode(BaseNode):
|
||||
|
||||
_node_data: IterationStartNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = IterationStartNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
|
||||
@ -105,7 +105,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
thread_pool_id: Optional[str] = None,
|
||||
*,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
) -> None:
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
@ -125,7 +125,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = KnowledgeRetrievalNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
|
||||
@ -41,7 +41,7 @@ class ListOperatorNode(BaseNode):
|
||||
|
||||
_node_data: ListOperatorNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = ListOperatorNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
|
||||
@ -41,5 +41,5 @@ class FileTypeNotSupportError(LLMNodeError):
|
||||
|
||||
|
||||
class UnsupportedPromptContentTypeError(LLMNodeError):
|
||||
def __init__(self, *, type_name: str) -> None:
|
||||
def __init__(self, *, type_name: str):
|
||||
super().__init__(f"Prompt content type {type_name} is not supported.")
|
||||
|
||||
@ -107,7 +107,7 @@ def fetch_memory(
|
||||
return memory
|
||||
|
||||
|
||||
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
|
||||
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
|
||||
@ -120,7 +120,7 @@ class LLMNode(BaseNode):
|
||||
thread_pool_id: Optional[str] = None,
|
||||
*,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
) -> None:
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
@ -140,7 +140,7 @@ class LLMNode(BaseNode):
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = LLMNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
@ -951,7 +951,7 @@ class LLMNode(BaseNode):
|
||||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
def get_default_config(cls, filters: Optional[dict] = None):
|
||||
return {
|
||||
"type": "llm",
|
||||
"config": {
|
||||
|
||||
@ -18,7 +18,7 @@ class LoopEndNode(BaseNode):
|
||||
|
||||
_node_data: LoopEndNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = LoopEndNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
|
||||
@ -54,7 +54,7 @@ class LoopNode(BaseNode):
|
||||
|
||||
_node_data: LoopNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = LoopNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
|
||||
@ -18,7 +18,7 @@ class LoopStartNode(BaseNode):
|
||||
|
||||
_node_data: LoopStartNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = LoopStartNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
|
||||
@ -98,7 +98,7 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
def set_reasoning_mode(cls, v) -> str:
|
||||
return v or "function_call"
|
||||
|
||||
def get_parameter_json_schema(self) -> dict:
|
||||
def get_parameter_json_schema(self):
|
||||
"""
|
||||
Get parameter json schema.
|
||||
|
||||
|
||||
@ -63,7 +63,7 @@ class InvalidValueTypeError(ParameterExtractorNodeError):
|
||||
expected_type: SegmentType,
|
||||
actual_type: SegmentType | None,
|
||||
value: Any,
|
||||
) -> None:
|
||||
):
|
||||
message = (
|
||||
f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, "
|
||||
f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}"
|
||||
|
||||
@ -94,7 +94,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
|
||||
_node_data: ParameterExtractorNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = ParameterExtractorNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
@ -119,7 +119,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
_model_config: Optional[ModelConfigWithCredentialsEntity] = None
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
def get_default_config(cls, filters: Optional[dict] = None):
|
||||
return {
|
||||
"model": {
|
||||
"prompt_templates": {
|
||||
@ -545,7 +545,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
|
||||
def _validate_result(self, data: ParameterExtractorNodeData, result: dict):
|
||||
if len(data.parameters) != len(result):
|
||||
raise InvalidNumberOfParametersError("Invalid number of parameters")
|
||||
|
||||
@ -597,7 +597,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: dict):
|
||||
"""
|
||||
Transform result into standard format.
|
||||
"""
|
||||
@ -690,7 +690,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
logger.info("extra error: %s", result)
|
||||
return None
|
||||
|
||||
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
|
||||
def _generate_default_result(self, data: ParameterExtractorNodeData):
|
||||
"""
|
||||
Generate default result.
|
||||
"""
|
||||
|
||||
@ -63,7 +63,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
thread_pool_id: Optional[str] = None,
|
||||
*,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
) -> None:
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
@ -83,7 +83,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = QuestionClassifierNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
@ -275,7 +275,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
def get_default_config(cls, filters: Optional[dict] = None):
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters.
|
||||
|
||||
@ -15,7 +15,7 @@ class StartNode(BaseNode):
|
||||
|
||||
_node_data: StartNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = StartNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
|
||||
@ -18,7 +18,7 @@ class TemplateTransformNode(BaseNode):
|
||||
|
||||
_node_data: TemplateTransformNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = TemplateTransformNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
@ -40,7 +40,7 @@ class TemplateTransformNode(BaseNode):
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
def get_default_config(cls, filters: Optional[dict] = None):
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters.
|
||||
|
||||
@ -45,7 +45,7 @@ class ToolNode(BaseNode):
|
||||
|
||||
_node_data: ToolNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = ToolNodeData.model_validate(data)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -15,7 +15,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
|
||||
_node_data: VariableAssignerNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = VariableAssignerNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
|
||||
@ -11,7 +11,7 @@ from .exc import VariableOperatorNodeError
|
||||
class ConversationVariableUpdaterImpl:
|
||||
_engine: Engine | None
|
||||
|
||||
def __init__(self, engine: Engine | None = None) -> None:
|
||||
def __init__(self, engine: Engine | None = None):
|
||||
self._engine = engine
|
||||
|
||||
def _get_engine(self) -> Engine:
|
||||
|
||||
@ -30,7 +30,7 @@ class VariableAssignerNode(BaseNode):
|
||||
|
||||
_node_data: VariableAssignerData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = VariableAssignerData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
@ -61,7 +61,7 @@ class VariableAssignerNode(BaseNode):
|
||||
previous_node_id: Optional[str] = None,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
|
||||
) -> None:
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
|
||||
@ -32,5 +32,5 @@ class ConversationIDNotFoundError(VariableOperatorNodeError):
|
||||
|
||||
|
||||
class InvalidDataError(VariableOperatorNodeError):
|
||||
def __init__(self, message: str) -> None:
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
|
||||
@ -58,7 +58,7 @@ class VariableAssignerNode(BaseNode):
|
||||
|
||||
_node_data: VariableAssignerNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = VariableAssignerNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
|
||||
@ -16,7 +16,7 @@ class WorkflowExecutionRepository(Protocol):
|
||||
application domains or deployment scenarios.
|
||||
"""
|
||||
|
||||
def save(self, execution: WorkflowExecution) -> None:
|
||||
def save(self, execution: WorkflowExecution):
|
||||
"""
|
||||
Save or update a WorkflowExecution instance.
|
||||
|
||||
|
||||
@ -26,7 +26,7 @@ class WorkflowNodeExecutionRepository(Protocol):
|
||||
application domains or deployment scenarios.
|
||||
"""
|
||||
|
||||
def save(self, execution: WorkflowNodeExecution) -> None:
|
||||
def save(self, execution: WorkflowNodeExecution):
|
||||
"""
|
||||
Save or update a NodeExecution instance.
|
||||
|
||||
|
||||
@ -57,7 +57,7 @@ class VariableTemplateParser:
|
||||
self.template = template
|
||||
self.variable_keys = self.extract()
|
||||
|
||||
def extract(self) -> list:
|
||||
def extract(self):
|
||||
"""
|
||||
Extracts all the template variable keys from the template string.
|
||||
|
||||
|
||||
@ -48,7 +48,7 @@ class WorkflowCycleManager:
|
||||
workflow_info: CycleManagerWorkflowInfo,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
) -> None:
|
||||
):
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_system_variables = workflow_system_variables
|
||||
self._workflow_info = workflow_info
|
||||
@ -299,7 +299,7 @@ class WorkflowCycleManager:
|
||||
error_message: Optional[str] = None,
|
||||
exceptions_count: int = 0,
|
||||
finished_at: Optional[datetime] = None,
|
||||
) -> None:
|
||||
):
|
||||
"""Update workflow execution with completion data."""
|
||||
execution.status = status
|
||||
execution.outputs = outputs or {}
|
||||
@ -316,7 +316,7 @@ class WorkflowCycleManager:
|
||||
workflow_execution: WorkflowExecution,
|
||||
conversation_id: Optional[str],
|
||||
external_trace_id: Optional[str],
|
||||
) -> None:
|
||||
):
|
||||
"""Add trace task if trace manager is provided."""
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
@ -334,7 +334,7 @@ class WorkflowCycleManager:
|
||||
workflow_execution_id: str,
|
||||
error_message: str,
|
||||
now: datetime,
|
||||
) -> None:
|
||||
):
|
||||
"""Fail all running node executions for a workflow."""
|
||||
running_node_executions = [
|
||||
node_exec
|
||||
@ -406,7 +406,7 @@ class WorkflowCycleManager:
|
||||
status: WorkflowNodeExecutionStatus,
|
||||
error: Optional[str] = None,
|
||||
handle_special_values: bool = False,
|
||||
) -> None:
|
||||
):
|
||||
"""Update node execution with completion data."""
|
||||
finished_at = naive_utc_now()
|
||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||
|
||||
@ -48,7 +48,7 @@ class WorkflowEntry:
|
||||
call_depth: int,
|
||||
variable_pool: VariablePool,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Init workflow entry
|
||||
:param tenant_id: tenant id
|
||||
@ -320,7 +320,7 @@ class WorkflowEntry:
|
||||
return result if isinstance(result, Mapping) or result is None else dict(result)
|
||||
|
||||
@staticmethod
|
||||
def _handle_special_values(value: Any) -> Any:
|
||||
def _handle_special_values(value: Any):
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, dict):
|
||||
@ -345,7 +345,7 @@ class WorkflowEntry:
|
||||
user_inputs: Mapping[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
):
|
||||
# NOTE(QuantumGhost): This logic should remain synchronized with
|
||||
# the implementation of `load_into_variable_pool`, specifically the logic about
|
||||
# variable existence checking.
|
||||
|
||||
@ -13,7 +13,7 @@ class WorkflowRuntimeTypeConverter:
|
||||
result = self._to_json_encodable_recursive(value)
|
||||
return result if isinstance(result, Mapping) or result is None else dict(result)
|
||||
|
||||
def _to_json_encodable_recursive(self, value: Any) -> Any:
|
||||
def _to_json_encodable_recursive(self, value: Any):
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, (bool, int, str, float)):
|
||||
|
||||
Reference in New Issue
Block a user