remove bare list, dict, Sequence, None, Any (#25058)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
Asuka Minato
2025-09-06 04:32:23 +09:00
committed by GitHub
parent 2b0695bdde
commit a78339a040
306 changed files with 787 additions and 817 deletions

View File

@@ -5,7 +5,7 @@ from core.workflow.graph_engine.entities.event import GraphEngineEvent
class WorkflowCallback(ABC):
@abstractmethod
def on_event(self, event: GraphEngineEvent) -> None:
def on_event(self, event: GraphEngineEvent):
"""
Published event
"""

View File

@@ -36,10 +36,10 @@ _TEXT_COLOR_MAPPING = {
class WorkflowLoggingCallback(WorkflowCallback):
def __init__(self) -> None:
def __init__(self):
self.current_node_id: Optional[str] = None
def on_event(self, event: GraphEngineEvent) -> None:
def on_event(self, event: GraphEngineEvent):
if isinstance(event, GraphRunStartedEvent):
self.print_text("\n[GraphRunStartedEvent]", color="pink")
elif isinstance(event, GraphRunSucceededEvent):
@@ -75,7 +75,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
else:
self.print_text(f"\n[{event.__class__.__name__}]", color="blue")
def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None:
def on_workflow_node_execute_started(self, event: NodeRunStartedEvent):
"""
Workflow node execute started
"""
@@ -84,7 +84,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
self.print_text(f"Node Title: {event.node_data.title}", color="yellow")
self.print_text(f"Type: {event.node_type.value}", color="yellow")
def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None:
def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent):
"""
Workflow node execute succeeded
"""
@@ -115,7 +115,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
color="green",
)
def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None:
def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent):
"""
Workflow node execute failed
"""
@@ -143,7 +143,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
color="red",
)
def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None:
def on_node_text_chunk(self, event: NodeRunStreamChunkEvent):
"""
Publish text chunk
"""
@@ -161,7 +161,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
self.print_text(event.chunk_content, color="pink", end="")
def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent) -> None:
def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent):
"""
Publish parallel started
"""
@@ -173,9 +173,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
if event.in_loop_id:
self.print_text(f"Loop ID: {event.in_loop_id}", color="blue")
def on_workflow_parallel_completed(
self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
) -> None:
def on_workflow_parallel_completed(self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
"""
Publish parallel completed
"""
@@ -200,14 +198,14 @@ class WorkflowLoggingCallback(WorkflowCallback):
if isinstance(event, ParallelBranchRunFailedEvent):
self.print_text(f"Error: {event.error}", color=color)
def on_workflow_iteration_started(self, event: IterationRunStartedEvent) -> None:
def on_workflow_iteration_started(self, event: IterationRunStartedEvent):
"""
Publish iteration started
"""
self.print_text("\n[IterationRunStartedEvent]", color="blue")
self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None:
def on_workflow_iteration_next(self, event: IterationRunNextEvent):
"""
Publish iteration next
"""
@@ -215,7 +213,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
self.print_text(f"Iteration Index: {event.index}", color="blue")
def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None:
def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent):
"""
Publish iteration completed
"""
@@ -227,14 +225,14 @@ class WorkflowLoggingCallback(WorkflowCallback):
)
self.print_text(f"Node ID: {event.iteration_id}", color="blue")
def on_workflow_loop_started(self, event: LoopRunStartedEvent) -> None:
def on_workflow_loop_started(self, event: LoopRunStartedEvent):
"""
Publish loop started
"""
self.print_text("\n[LoopRunStartedEvent]", color="blue")
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
def on_workflow_loop_next(self, event: LoopRunNextEvent) -> None:
def on_workflow_loop_next(self, event: LoopRunNextEvent):
"""
Publish loop next
"""
@@ -242,7 +240,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
self.print_text(f"Loop Index: {event.index}", color="blue")
def on_workflow_loop_completed(self, event: LoopRunSucceededEvent | LoopRunFailedEvent) -> None:
def on_workflow_loop_completed(self, event: LoopRunSucceededEvent | LoopRunFailedEvent):
"""
Publish loop completed
"""
@@ -252,7 +250,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
)
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None:
def print_text(self, text: str, color: Optional[str] = None, end: str = "\n"):
"""Print text with highlighting and no end characters."""
text_to_print = self._get_colored_text(text, color) if color else text
print(f"{text_to_print}", end=end)

View File

@@ -20,7 +20,7 @@ class ConversationVariableUpdater(Protocol):
"""
@abc.abstractmethod
def update(self, conversation_id: str, variable: "Variable") -> None:
def update(self, conversation_id: str, variable: "Variable"):
"""
Updates the value of the specified conversation variable in the underlying storage.

View File

@@ -47,7 +47,7 @@ class VariablePool(BaseModel):
default_factory=list,
)
def model_post_init(self, context: Any, /) -> None:
def model_post_init(self, context: Any, /):
# Create a mapping from field names to SystemVariableKey enum values
self._add_system_variables(self.system_variables)
# Add environment variables to the variable pool
@@ -57,7 +57,7 @@ class VariablePool(BaseModel):
for var in self.conversation_variables:
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
def add(self, selector: Sequence[str], value: Any, /) -> None:
def add(self, selector: Sequence[str], value: Any, /):
"""
Add a variable to the variable pool.
@@ -161,11 +161,11 @@ class VariablePool(BaseModel):
# Return result as Segment
return result if isinstance(result, Segment) else variable_factory.build_segment(result)
def _extract_value(self, obj: Any) -> Any:
def _extract_value(self, obj: Any):
"""Extract the actual value from an ObjectSegment."""
return obj.value if isinstance(obj, ObjectSegment) else obj
def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Any:
def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str):
"""Get a nested attribute from a dictionary-like object."""
if not isinstance(obj, dict):
return None

View File

@@ -112,7 +112,7 @@ class WorkflowNodeExecution(BaseModel):
process_data: Optional[Mapping[str, Any]] = None,
outputs: Optional[Mapping[str, Any]] = None,
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None,
) -> None:
):
"""
Update the model from mappings.

View File

@@ -205,9 +205,7 @@ class Graph(BaseModel):
return graph
@classmethod
def _recursively_add_node_ids(
cls, node_ids: list[str], edge_mapping: dict[str, list[GraphEdge]], node_id: str
) -> None:
def _recursively_add_node_ids(cls, node_ids: list[str], edge_mapping: dict[str, list[GraphEdge]], node_id: str):
"""
Recursively add node ids
@@ -225,7 +223,7 @@ class Graph(BaseModel):
)
@classmethod
def _check_connected_to_previous_node(cls, route: list[str], edge_mapping: dict[str, list[GraphEdge]]) -> None:
def _check_connected_to_previous_node(cls, route: list[str], edge_mapping: dict[str, list[GraphEdge]]):
"""
Check whether it is connected to the previous node
"""
@@ -256,7 +254,7 @@ class Graph(BaseModel):
parallel_mapping: dict[str, GraphParallel],
node_parallel_mapping: dict[str, str],
parent_parallel: Optional[GraphParallel] = None,
) -> None:
):
"""
Recursively add parallel ids
@@ -461,7 +459,7 @@ class Graph(BaseModel):
level_limit: int,
parent_parallel_id: str,
current_level: int = 1,
) -> None:
):
"""
Check if it exceeds N layers of parallel
"""
@@ -488,7 +486,7 @@ class Graph(BaseModel):
edge_mapping: dict[str, list[GraphEdge]],
merge_node_id: str,
start_node_id: str,
) -> None:
):
"""
Recursively add node ids
@@ -614,7 +612,7 @@ class Graph(BaseModel):
@classmethod
def _recursively_fetch_routes(
cls, edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: list[str]
) -> None:
):
"""
Recursively fetch route
"""

View File

@@ -47,7 +47,7 @@ class RouteNodeState(BaseModel):
index: int = 1
def set_finished(self, run_result: NodeRunResult) -> None:
def set_finished(self, run_result: NodeRunResult):
"""
Node finished
@@ -94,7 +94,7 @@ class RuntimeRouteState(BaseModel):
self.node_state_mapping[state.id] = state
return state
def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None:
def add_route(self, source_node_state_id: str, target_node_state_id: str):
"""
Add route to the graph state

View File

@@ -66,7 +66,7 @@ class GraphEngineThreadPool(ThreadPoolExecutor):
initializer=None,
initargs=(),
max_submit_count=dify_config.MAX_SUBMIT_COUNT,
) -> None:
):
super().__init__(max_workers, thread_name_prefix, initializer, initargs)
self.max_submit_count = max_submit_count
self.submit_count = 0
@@ -80,7 +80,7 @@ class GraphEngineThreadPool(ThreadPoolExecutor):
def task_done_callback(self, future):
self.submit_count -= 1
def check_is_full(self) -> None:
def check_is_full(self):
if self.submit_count > self.max_submit_count:
raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.")
@@ -104,7 +104,7 @@ class GraphEngine:
max_execution_steps: int,
max_execution_time: int,
thread_pool_id: Optional[str] = None,
) -> None:
):
thread_pool_max_submit_count = dify_config.MAX_SUBMIT_COUNT
thread_pool_max_workers = 10
@@ -537,7 +537,7 @@ class GraphEngine:
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> None:
):
"""
Run parallel nodes
"""

View File

@@ -66,7 +66,7 @@ class AgentNode(BaseNode):
_node_type = NodeType.AGENT
_node_data: AgentNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = AgentNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@@ -22,7 +22,7 @@ class AnswerNode(BaseNode):
_node_data: AnswerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = AnswerNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@@ -134,7 +134,7 @@ class AnswerStreamGeneratorRouter:
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
answer_dependencies: dict[str, list[str]],
) -> None:
):
"""
Recursive fetch answer dependencies
:param current_node_id: current node id

View File

@@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
class AnswerStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
def __init__(self, graph: Graph, variable_pool: VariablePool):
super().__init__(graph, variable_pool)
self.generate_routes = graph.answer_stream_generate_routes
self.route_position = {}
@@ -66,7 +66,7 @@ class AnswerStreamProcessor(StreamProcessor):
else:
yield event
def reset(self) -> None:
def reset(self):
self.route_position = {}
for answer_node_id, _ in self.generate_routes.answer_generate_route.items():
self.route_position[answer_node_id] = 0

View File

@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
class StreamProcessor(ABC):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
def __init__(self, graph: Graph, variable_pool: VariablePool):
self.graph = graph
self.variable_pool = variable_pool
self.rest_node_ids = graph.node_ids.copy()
@@ -20,7 +20,7 @@ class StreamProcessor(ABC):
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
raise NotImplementedError
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None:
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent):
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return
@@ -89,7 +89,7 @@ class StreamProcessor(ABC):
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify))
return node_ids
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]):
"""
remove target node ids until merge
"""

View File

@@ -28,7 +28,7 @@ class DefaultValue(BaseModel):
key: str
@staticmethod
def _parse_json(value: str) -> Any:
def _parse_json(value: str):
"""Unified JSON parsing handler"""
try:
return json.loads(value)

View File

@@ -28,7 +28,7 @@ class BaseNode:
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
) -> None:
):
self.id = id
self.tenant_id = graph_init_params.tenant_id
self.app_id = graph_init_params.app_id
@@ -51,7 +51,7 @@ class BaseNode:
self.node_id = node_id
@abstractmethod
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
def init_node_data(self, data: Mapping[str, Any]): ...
@abstractmethod
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
@@ -141,7 +141,7 @@ class BaseNode:
return {}
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: Optional[dict] = None):
return {}
@property

View File

@@ -28,7 +28,7 @@ class CodeNode(BaseNode):
_node_data: CodeNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = CodeNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
@@ -50,7 +50,7 @@ class CodeNode(BaseNode):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: Optional[dict] = None):
"""
Get default config of node.
:param filters: filter by node config parameters.

View File

@@ -47,7 +47,7 @@ class DocumentExtractorNode(BaseNode):
_node_data: DocumentExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = DocumentExtractorNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@@ -14,7 +14,7 @@ class EndNode(BaseNode):
_node_data: EndNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = EndNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@@ -121,7 +121,7 @@ class EndStreamGeneratorRouter:
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
end_dependencies: dict[str, list[str]],
) -> None:
):
"""
Recursive fetch end dependencies
:param current_node_id: current node id

View File

@@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
class EndStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
def __init__(self, graph: Graph, variable_pool: VariablePool):
super().__init__(graph, variable_pool)
self.end_stream_param = graph.end_stream_param
self.route_position = {}
@@ -76,7 +76,7 @@ class EndStreamProcessor(StreamProcessor):
else:
yield event
def reset(self) -> None:
def reset(self):
self.route_position = {}
for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
self.route_position[end_node_id] = 0

View File

@@ -38,7 +38,7 @@ class HttpRequestNode(BaseNode):
_node_data: HttpRequestNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = HttpRequestNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
@@ -60,7 +60,7 @@ class HttpRequestNode(BaseNode):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
def get_default_config(cls, filters: Optional[dict[str, Any]] = None):
return {
"type": "http-request",
"config": {

View File

@@ -19,7 +19,7 @@ class IfElseNode(BaseNode):
_node_data: IfElseNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IfElseNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@@ -67,7 +67,7 @@ class IterationNode(BaseNode):
_node_data: IterationNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IterationNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
@@ -89,7 +89,7 @@ class IterationNode(BaseNode):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: Optional[dict] = None):
return {
"type": "iteration",
"config": {

View File

@@ -18,7 +18,7 @@ class IterationStartNode(BaseNode):
_node_data: IterationStartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IterationStartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@@ -105,7 +105,7 @@ class KnowledgeRetrievalNode(BaseNode):
thread_pool_id: Optional[str] = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
):
super().__init__(
id=id,
config=config,
@@ -125,7 +125,7 @@ class KnowledgeRetrievalNode(BaseNode):
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = KnowledgeRetrievalNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@@ -41,7 +41,7 @@ class ListOperatorNode(BaseNode):
_node_data: ListOperatorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ListOperatorNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@@ -41,5 +41,5 @@ class FileTypeNotSupportError(LLMNodeError):
class UnsupportedPromptContentTypeError(LLMNodeError):
def __init__(self, *, type_name: str) -> None:
def __init__(self, *, type_name: str):
super().__init__(f"Prompt content type {type_name} is not supported.")

View File

@@ -107,7 +107,7 @@ def fetch_memory(
return memory
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration

View File

@@ -120,7 +120,7 @@ class LLMNode(BaseNode):
thread_pool_id: Optional[str] = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
):
super().__init__(
id=id,
config=config,
@@ -140,7 +140,7 @@ class LLMNode(BaseNode):
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LLMNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
@@ -951,7 +951,7 @@ class LLMNode(BaseNode):
return variable_mapping
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: Optional[dict] = None):
return {
"type": "llm",
"config": {

View File

@@ -18,7 +18,7 @@ class LoopEndNode(BaseNode):
_node_data: LoopEndNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopEndNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@@ -54,7 +54,7 @@ class LoopNode(BaseNode):
_node_data: LoopNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@@ -18,7 +18,7 @@ class LoopStartNode(BaseNode):
_node_data: LoopStartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopStartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@@ -98,7 +98,7 @@ class ParameterExtractorNodeData(BaseNodeData):
def set_reasoning_mode(cls, v) -> str:
return v or "function_call"
def get_parameter_json_schema(self) -> dict:
def get_parameter_json_schema(self):
"""
Get parameter json schema.

View File

@@ -63,7 +63,7 @@ class InvalidValueTypeError(ParameterExtractorNodeError):
expected_type: SegmentType,
actual_type: SegmentType | None,
value: Any,
) -> None:
):
message = (
f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, "
f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}"

View File

@@ -94,7 +94,7 @@ class ParameterExtractorNode(BaseNode):
_node_data: ParameterExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ParameterExtractorNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
@@ -119,7 +119,7 @@ class ParameterExtractorNode(BaseNode):
_model_config: Optional[ModelConfigWithCredentialsEntity] = None
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: Optional[dict] = None):
return {
"model": {
"prompt_templates": {
@@ -545,7 +545,7 @@ class ParameterExtractorNode(BaseNode):
return prompt_messages
def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
def _validate_result(self, data: ParameterExtractorNodeData, result: dict):
if len(data.parameters) != len(result):
raise InvalidNumberOfParametersError("Invalid number of parameters")
@@ -597,7 +597,7 @@ class ParameterExtractorNode(BaseNode):
except ValueError:
return None
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
def _transform_result(self, data: ParameterExtractorNodeData, result: dict):
"""
Transform result into standard format.
"""
@@ -690,7 +690,7 @@ class ParameterExtractorNode(BaseNode):
logger.info("extra error: %s", result)
return None
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
def _generate_default_result(self, data: ParameterExtractorNodeData):
"""
Generate default result.
"""

View File

@@ -63,7 +63,7 @@ class QuestionClassifierNode(BaseNode):
thread_pool_id: Optional[str] = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
):
super().__init__(
id=id,
config=config,
@@ -83,7 +83,7 @@ class QuestionClassifierNode(BaseNode):
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = QuestionClassifierNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
@@ -275,7 +275,7 @@ class QuestionClassifierNode(BaseNode):
return variable_mapping
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: Optional[dict] = None):
"""
Get default config of node.
:param filters: filter by node config parameters.

View File

@@ -15,7 +15,7 @@ class StartNode(BaseNode):
_node_data: StartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = StartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@@ -18,7 +18,7 @@ class TemplateTransformNode(BaseNode):
_node_data: TemplateTransformNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = TemplateTransformNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
@@ -40,7 +40,7 @@ class TemplateTransformNode(BaseNode):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: Optional[dict] = None):
"""
Get default config of node.
:param filters: filter by node config parameters.

View File

@@ -45,7 +45,7 @@ class ToolNode(BaseNode):
_node_data: ToolNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ToolNodeData.model_validate(data)
@classmethod

View File

@@ -15,7 +15,7 @@ class VariableAggregatorNode(BaseNode):
_node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@@ -11,7 +11,7 @@ from .exc import VariableOperatorNodeError
class ConversationVariableUpdaterImpl:
_engine: Engine | None
def __init__(self, engine: Engine | None = None) -> None:
def __init__(self, engine: Engine | None = None):
self._engine = engine
def _get_engine(self) -> Engine:

View File

@@ -30,7 +30,7 @@ class VariableAssignerNode(BaseNode):
_node_data: VariableAssignerData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
@@ -61,7 +61,7 @@ class VariableAssignerNode(BaseNode):
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
) -> None:
):
super().__init__(
id=id,
config=config,

View File

@@ -32,5 +32,5 @@ class ConversationIDNotFoundError(VariableOperatorNodeError):
class InvalidDataError(VariableOperatorNodeError):
def __init__(self, message: str) -> None:
def __init__(self, message: str):
super().__init__(message)

View File

@@ -58,7 +58,7 @@ class VariableAssignerNode(BaseNode):
_node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@@ -16,7 +16,7 @@ class WorkflowExecutionRepository(Protocol):
application domains or deployment scenarios.
"""
def save(self, execution: WorkflowExecution) -> None:
def save(self, execution: WorkflowExecution):
"""
Save or update a WorkflowExecution instance.

View File

@@ -26,7 +26,7 @@ class WorkflowNodeExecutionRepository(Protocol):
application domains or deployment scenarios.
"""
def save(self, execution: WorkflowNodeExecution) -> None:
def save(self, execution: WorkflowNodeExecution):
"""
Save or update a NodeExecution instance.

View File

@@ -57,7 +57,7 @@ class VariableTemplateParser:
self.template = template
self.variable_keys = self.extract()
def extract(self) -> list:
def extract(self):
"""
Extracts all the template variable keys from the template string.

View File

@@ -48,7 +48,7 @@ class WorkflowCycleManager:
workflow_info: CycleManagerWorkflowInfo,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
):
self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables
self._workflow_info = workflow_info
@@ -299,7 +299,7 @@ class WorkflowCycleManager:
error_message: Optional[str] = None,
exceptions_count: int = 0,
finished_at: Optional[datetime] = None,
) -> None:
):
"""Update workflow execution with completion data."""
execution.status = status
execution.outputs = outputs or {}
@@ -316,7 +316,7 @@ class WorkflowCycleManager:
workflow_execution: WorkflowExecution,
conversation_id: Optional[str],
external_trace_id: Optional[str],
) -> None:
):
"""Add trace task if trace manager is provided."""
if trace_manager:
trace_manager.add_trace_task(
@@ -334,7 +334,7 @@ class WorkflowCycleManager:
workflow_execution_id: str,
error_message: str,
now: datetime,
) -> None:
):
"""Fail all running node executions for a workflow."""
running_node_executions = [
node_exec
@@ -406,7 +406,7 @@ class WorkflowCycleManager:
status: WorkflowNodeExecutionStatus,
error: Optional[str] = None,
handle_special_values: bool = False,
) -> None:
):
"""Update node execution with completion data."""
finished_at = naive_utc_now()
elapsed_time = (finished_at - event.start_at).total_seconds()

View File

@@ -48,7 +48,7 @@ class WorkflowEntry:
call_depth: int,
variable_pool: VariablePool,
thread_pool_id: Optional[str] = None,
) -> None:
):
"""
Init workflow entry
:param tenant_id: tenant id
@@ -320,7 +320,7 @@ class WorkflowEntry:
return result if isinstance(result, Mapping) or result is None else dict(result)
@staticmethod
def _handle_special_values(value: Any) -> Any:
def _handle_special_values(value: Any):
if value is None:
return value
if isinstance(value, dict):
@@ -345,7 +345,7 @@ class WorkflowEntry:
user_inputs: Mapping[str, Any],
variable_pool: VariablePool,
tenant_id: str,
) -> None:
):
# NOTE(QuantumGhost): This logic should remain synchronized with
# the implementation of `load_into_variable_pool`, specifically the logic about
# variable existence checking.

View File

@@ -13,7 +13,7 @@ class WorkflowRuntimeTypeConverter:
result = self._to_json_encodable_recursive(value)
return result if isinstance(result, Mapping) or result is None else dict(result)
def _to_json_encodable_recursive(self, value: Any) -> Any:
def _to_json_encodable_recursive(self, value: Any):
if value is None:
return value
if isinstance(value, (bool, int, str, float)):