chore: add ast-grep rule to convert Optional[T] to T | None (#25560)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
-LAN-
2025-09-15 13:06:33 +08:00
committed by GitHub
parent 2e44ebe98d
commit bab4975809
394 changed files with 2555 additions and 2792 deletions

View File

@ -1,6 +1,6 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, cast
from packaging.version import Version
from pydantic import ValidationError
@ -69,7 +69,7 @@ class AgentNode(BaseNode):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = AgentNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -78,7 +78,7 @@ class AgentNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -401,7 +401,7 @@ class AgentNode(BaseNode):
icon = None
return icon
def _fetch_memory(self, model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
# get conversation id
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID.value]

View File

@ -1,6 +1,3 @@
from typing import Optional
class AgentNodeError(Exception):
"""Base exception for all agent node errors."""
@ -12,7 +9,7 @@ class AgentNodeError(Exception):
class AgentStrategyError(AgentNodeError):
"""Exception raised when there's an error with the agent strategy."""
def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None):
def __init__(self, message: str, strategy_name: str | None = None, provider_name: str | None = None):
self.strategy_name = strategy_name
self.provider_name = provider_name
super().__init__(message)
@ -21,7 +18,7 @@ class AgentStrategyError(AgentNodeError):
class AgentStrategyNotFoundError(AgentStrategyError):
"""Exception raised when the specified agent strategy is not found."""
def __init__(self, strategy_name: str, provider_name: Optional[str] = None):
def __init__(self, strategy_name: str, provider_name: str | None = None):
super().__init__(
f"Agent strategy '{strategy_name}' not found"
+ (f" for provider '{provider_name}'" if provider_name else ""),
@ -33,7 +30,7 @@ class AgentStrategyNotFoundError(AgentStrategyError):
class AgentInvocationError(AgentNodeError):
"""Exception raised when there's an error invoking the agent."""
def __init__(self, message: str, original_error: Optional[Exception] = None):
def __init__(self, message: str, original_error: Exception | None = None):
self.original_error = original_error
super().__init__(message)
@ -41,7 +38,7 @@ class AgentInvocationError(AgentNodeError):
class AgentParameterError(AgentNodeError):
"""Exception raised when there's an error with agent parameters."""
def __init__(self, message: str, parameter_name: Optional[str] = None):
def __init__(self, message: str, parameter_name: str | None = None):
self.parameter_name = parameter_name
super().__init__(message)
@ -49,7 +46,7 @@ class AgentParameterError(AgentNodeError):
class AgentVariableError(AgentNodeError):
"""Exception raised when there's an error with variables in the agent node."""
def __init__(self, message: str, variable_name: Optional[str] = None):
def __init__(self, message: str, variable_name: str | None = None):
self.variable_name = variable_name
super().__init__(message)
@ -71,7 +68,7 @@ class AgentInputTypeError(AgentNodeError):
class ToolFileError(AgentNodeError):
"""Exception raised when there's an error with a tool file."""
def __init__(self, message: str, file_id: Optional[str] = None):
def __init__(self, message: str, file_id: str | None = None):
self.file_id = file_id
super().__init__(message)
@ -86,7 +83,7 @@ class ToolFileNotFoundError(ToolFileError):
class AgentMessageTransformError(AgentNodeError):
"""Exception raised when there's an error transforming agent messages."""
def __init__(self, message: str, original_error: Optional[Exception] = None):
def __init__(self, message: str, original_error: Exception | None = None):
self.original_error = original_error
super().__init__(message)
@ -94,7 +91,7 @@ class AgentMessageTransformError(AgentNodeError):
class AgentModelError(AgentNodeError):
"""Exception raised when there's an error with the model used by the agent."""
def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None):
def __init__(self, message: str, model_name: str | None = None, provider: str | None = None):
self.model_name = model_name
self.provider = provider
super().__init__(message)
@ -103,7 +100,7 @@ class AgentModelError(AgentNodeError):
class AgentMemoryError(AgentNodeError):
"""Exception raised when there's an error with the agent's memory."""
def __init__(self, message: str, conversation_id: Optional[str] = None):
def __init__(self, message: str, conversation_id: str | None = None):
self.conversation_id = conversation_id
super().__init__(message)
@ -114,9 +111,9 @@ class AgentVariableTypeError(AgentNodeError):
def __init__(
self,
message: str,
variable_name: Optional[str] = None,
expected_type: Optional[str] = None,
actual_type: Optional[str] = None,
variable_name: str | None = None,
expected_type: str | None = None,
actual_type: str | None = None,
):
self.variable_name = variable_name
self.expected_type = expected_type

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, cast
from core.variables import ArrayFileSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
@ -25,7 +25,7 @@ class AnswerNode(BaseNode):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = AnswerNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -34,7 +34,7 @@ class AnswerNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,7 +1,6 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import Optional
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
@ -72,7 +71,7 @@ class StreamProcessor(ABC):
for node_id in unreachable_first_node_ids:
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]:
def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: str | None = None) -> list[str]:
if node_id not in self.rest_node_ids:
self.rest_node_ids.append(node_id)
node_ids = []

View File

@ -1,7 +1,7 @@
import json
from abc import ABC
from enum import StrEnum
from typing import Any, Optional, Union
from typing import Any, Union
from pydantic import BaseModel, model_validator
@ -121,10 +121,10 @@ class RetryConfig(BaseModel):
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
desc: str | None = None
version: str = "1"
error_strategy: Optional[ErrorStrategy] = None
default_value: Optional[list[DefaultValue]] = None
error_strategy: ErrorStrategy | None = None
default_value: list[DefaultValue] | None = None
retry_config: RetryConfig = RetryConfig()
@property
@ -135,7 +135,7 @@ class BaseNodeData(ABC, BaseModel):
class BaseIterationNodeData(BaseNodeData):
start_node_id: Optional[str] = None
start_node_id: str | None = None
class BaseIterationState(BaseModel):
@ -150,7 +150,7 @@ class BaseIterationState(BaseModel):
class BaseLoopNodeData(BaseNodeData):
start_node_id: Optional[str] = None
start_node_id: str | None = None
class BaseLoopState(BaseModel):

View File

@ -1,7 +1,7 @@
import logging
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
from typing import TYPE_CHECKING, Any, ClassVar, Union
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -26,8 +26,8 @@ class BaseNode:
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
previous_node_id: str | None = None,
thread_pool_id: str | None = None,
):
self.id = id
self.tenant_id = graph_init_params.tenant_id
@ -141,7 +141,7 @@ class BaseNode:
return {}
@classmethod
def get_default_config(cls, filters: Optional[dict] = None):
def get_default_config(cls, filters: dict | None = None):
return {}
@property
@ -170,7 +170,7 @@ class BaseNode:
# to BaseNodeData properties in a type-safe way
@abstractmethod
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
"""Get the error strategy for this node."""
...
@ -185,7 +185,7 @@ class BaseNode:
...
@abstractmethod
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
"""Get the node description."""
...
@ -201,7 +201,7 @@ class BaseNode:
# Public interface properties that delegate to abstract methods
@property
def error_strategy(self) -> Optional[ErrorStrategy]:
def error_strategy(self) -> ErrorStrategy | None:
"""Get the error strategy for this node."""
return self._get_error_strategy()
@ -216,7 +216,7 @@ class BaseNode:
return self._get_title()
@property
def description(self) -> Optional[str]:
def description(self) -> str | None:
"""Get the node description."""
return self._get_description()

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping, Sequence
from decimal import Decimal
from typing import Any, Optional
from typing import Any
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
@ -31,7 +31,7 @@ class CodeNode(BaseNode):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = CodeNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -40,7 +40,7 @@ class CodeNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -50,7 +50,7 @@ class CodeNode(BaseNode):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None):
def get_default_config(cls, filters: dict | None = None):
"""
Get default config of node.
:param filters: filter by node config parameters.
@ -161,7 +161,7 @@ class CodeNode(BaseNode):
def _transform_result(
self,
result: Mapping[str, Any],
output_schema: Optional[dict[str, CodeNodeData.Output]],
output_schema: dict[str, CodeNodeData.Output] | None,
prefix: str = "",
depth: int = 1,
):

View File

@ -1,4 +1,4 @@
from typing import Annotated, Literal, Optional
from typing import Annotated, Literal
from pydantic import AfterValidator, BaseModel
@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData):
class Output(BaseModel):
type: Annotated[SegmentType, AfterValidator(_validate_type)]
children: Optional[dict[str, "CodeNodeData.Output"]] = None
children: dict[str, "CodeNodeData.Output"] | None = None
class Dependency(BaseModel):
name: str
@ -44,4 +44,4 @@ class CodeNodeData(BaseNodeData):
code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]
code: str
outputs: dict[str, Output]
dependencies: Optional[list[Dependency]] = None
dependencies: list[Dependency] | None = None

View File

@ -5,7 +5,7 @@ import logging
import os
import tempfile
from collections.abc import Mapping, Sequence
from typing import Any, Optional
from typing import Any
import chardet
import docx
@ -50,7 +50,7 @@ class DocumentExtractorNode(BaseNode):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = DocumentExtractorNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -59,7 +59,7 @@ class DocumentExtractorNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -17,7 +17,7 @@ class EndNode(BaseNode):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = EndNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -26,7 +26,7 @@ class EndNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,7 +1,7 @@
import mimetypes
from collections.abc import Sequence
from email.message import Message
from typing import Any, Literal, Optional
from typing import Any, Literal
import httpx
from pydantic import BaseModel, Field, ValidationInfo, field_validator
@ -18,7 +18,7 @@ class HttpRequestNodeAuthorizationConfig(BaseModel):
class HttpRequestNodeAuthorization(BaseModel):
type: Literal["no-auth", "api-key"]
config: Optional[HttpRequestNodeAuthorizationConfig] = None
config: HttpRequestNodeAuthorizationConfig | None = None
@field_validator("config", mode="before")
@classmethod
@ -88,9 +88,9 @@ class HttpRequestNodeData(BaseNodeData):
authorization: HttpRequestNodeAuthorization
headers: str
params: str
body: Optional[HttpRequestNodeBody] = None
timeout: Optional[HttpRequestNodeTimeout] = None
ssl_verify: Optional[bool] = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
body: HttpRequestNodeBody | None = None
timeout: HttpRequestNodeTimeout | None = None
ssl_verify: bool | None = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
class Response:
@ -183,7 +183,7 @@ class Response:
return f"{(self.size / 1024 / 1024):.2f} MB"
@property
def parsed_content_disposition(self) -> Optional[Message]:
def parsed_content_disposition(self) -> Message | None:
content_disposition = self.headers.get("content-disposition", "")
if content_disposition:
msg = Message()

View File

@ -1,7 +1,7 @@
import logging
import mimetypes
from collections.abc import Mapping, Sequence
from typing import Any, Optional
from typing import Any
from configs import dify_config
from core.file import File, FileTransferMethod
@ -41,7 +41,7 @@ class HttpRequestNode(BaseNode):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = HttpRequestNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -50,7 +50,7 @@ class HttpRequestNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -60,7 +60,7 @@ class HttpRequestNode(BaseNode):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict[str, Any]] = None):
def get_default_config(cls, filters: dict[str, Any] | None = None):
return {
"type": "http-request",
"config": {

View File

@ -1,4 +1,4 @@
from typing import Literal, Optional
from typing import Literal
from pydantic import BaseModel, Field
@ -20,7 +20,7 @@ class IfElseNodeData(BaseNodeData):
logical_operator: Literal["and", "or"]
conditions: list[Condition]
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
logical_operator: Literal["and", "or"] | None = "and"
conditions: list[Condition] | None = Field(default=None, deprecated=True)
cases: Optional[list[Case]] = None
cases: list[Case] | None = None

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal, Optional
from typing import Any, Literal
from typing_extensions import deprecated
@ -22,7 +22,7 @@ class IfElseNode(BaseNode):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IfElseNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -31,7 +31,7 @@ class IfElseNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,5 +1,5 @@
from enum import StrEnum
from typing import Any, Optional
from typing import Any
from pydantic import Field
@ -17,7 +17,7 @@ class IterationNodeData(BaseIterationNodeData):
Iteration Node Data.
"""
parent_loop_id: Optional[str] = None # redundant field, not used currently
parent_loop_id: str | None = None # redundant field, not used currently
iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector
is_parallel: bool = False # open the parallel mode or not
@ -39,7 +39,7 @@ class IterationState(BaseIterationState):
"""
outputs: list[Any] = Field(default_factory=list)
current_output: Optional[Any] = None
current_output: Any | None = None
class MetaData(BaseIterationState.MetaData):
"""
@ -48,7 +48,7 @@ class IterationState(BaseIterationState):
iterator_length: int
def get_last_output(self) -> Optional[Any]:
def get_last_output(self) -> Any | None:
"""
Get last output.
"""
@ -56,7 +56,7 @@ class IterationState(BaseIterationState):
return self.outputs[-1]
return None
def get_current_output(self) -> Optional[Any]:
def get_current_output(self) -> Any | None:
"""
Get current output.
"""

View File

@ -6,7 +6,7 @@ from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, wait
from datetime import datetime
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, cast
from flask import Flask, current_app
@ -70,7 +70,7 @@ class IterationNode(BaseNode):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IterationNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -79,7 +79,7 @@ class IterationNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -89,7 +89,7 @@ class IterationNode(BaseNode):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None):
def get_default_config(cls, filters: dict | None = None):
return {
"type": "iteration",
"config": {
@ -424,7 +424,7 @@ class IterationNode(BaseNode):
graph_engine: "GraphEngine",
iteration_graph: Graph,
iter_run_map: dict[str, float],
parallel_mode_run_id: Optional[str] = None,
parallel_mode_run_id: str | None = None,
) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
run single iteration

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -21,7 +21,7 @@ class IterationStartNode(BaseNode):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IterationStartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -30,7 +30,7 @@ class IterationStartNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Literal, Optional
from typing import Literal
from pydantic import BaseModel, Field
@ -49,11 +49,11 @@ class MultipleRetrievalConfig(BaseModel):
"""
top_k: int
score_threshold: Optional[float] = None
score_threshold: float | None = None
reranking_mode: str = "reranking_model"
reranking_enable: bool = True
reranking_model: Optional[RerankingModelConfig] = None
weights: Optional[WeightedScoreConfig] = None
reranking_model: RerankingModelConfig | None = None
weights: WeightedScoreConfig | None = None
class SingleRetrievalConfig(BaseModel):
@ -104,8 +104,8 @@ class MetadataFilteringCondition(BaseModel):
Metadata Filtering Condition.
"""
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
logical_operator: Literal["and", "or"] | None = "and"
conditions: list[Condition] | None = Field(default=None, deprecated=True)
class KnowledgeRetrievalNodeData(BaseNodeData):
@ -117,11 +117,11 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
query_variable_selector: list[str]
dataset_ids: list[str]
retrieval_mode: Literal["single", "multiple"]
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
single_retrieval_config: Optional[SingleRetrievalConfig] = None
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
metadata_model_config: Optional[ModelConfig] = None
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
multiple_retrieval_config: MultipleRetrievalConfig | None = None
single_retrieval_config: SingleRetrievalConfig | None = None
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
metadata_model_config: ModelConfig | None = None
metadata_filtering_conditions: MetadataFilteringCondition | None = None
vision: VisionConfig = Field(default_factory=VisionConfig)
@property

View File

@ -4,7 +4,7 @@ import re
import time
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import Float, and_, func, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast
@ -101,8 +101,8 @@ class KnowledgeRetrievalNode(BaseNode):
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
previous_node_id: str | None = None,
thread_pool_id: str | None = None,
*,
llm_file_saver: LLMFileSaver | None = None,
):
@ -128,7 +128,7 @@ class KnowledgeRetrievalNode(BaseNode):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = KnowledgeRetrievalNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -137,7 +137,7 @@ class KnowledgeRetrievalNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -419,7 +419,7 @@ class KnowledgeRetrievalNode(BaseNode):
def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
document_query = db.session.query(Document).where(
Document.dataset_id.in_(dataset_ids),
Document.indexing_status == "completed",
@ -576,7 +576,7 @@ class KnowledgeRetrievalNode(BaseNode):
return automatic_metadata_filters
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
):
if value is None and condition not in ("empty", "not empty"):
return

View File

@ -1,5 +1,5 @@
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Optional, TypeAlias, TypeVar
from typing import Any, TypeAlias, TypeVar
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
@ -44,7 +44,7 @@ class ListOperatorNode(BaseNode):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ListOperatorNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -53,7 +53,7 @@ class ListOperatorNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal, Optional
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
@ -18,7 +18,7 @@ class ModelConfig(BaseModel):
class ContextConfig(BaseModel):
enabled: bool
variable_selector: Optional[list[str]] = None
variable_selector: list[str] | None = None
class VisionConfigOptions(BaseModel):
@ -51,18 +51,18 @@ class PromptConfig(BaseModel):
class LLMNodeChatModelMessage(ChatModelMessage):
text: str = ""
jinja2_text: Optional[str] = None
jinja2_text: str | None = None
class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
jinja2_text: Optional[str] = None
jinja2_text: str | None = None
class LLMNodeData(BaseNodeData):
model: ModelConfig
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
memory: Optional[MemoryConfig] = None
memory: MemoryConfig | None = None
context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: Mapping[str, Any] | None = None

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Optional, cast
from typing import cast
from sqlalchemy import select, update
from sqlalchemy.orm import Session
@ -86,8 +86,8 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
def fetch_memory(
variable_pool: VariablePool, app_id: str, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
) -> Optional[TokenBufferMemory]:
variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
) -> TokenBufferMemory | None:
if not node_data_memory:
return None

View File

@ -4,7 +4,7 @@ import json
import logging
import re
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Literal, Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
@ -116,8 +116,8 @@ class LLMNode(BaseNode):
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
previous_node_id: str | None = None,
thread_pool_id: str | None = None,
*,
llm_file_saver: LLMFileSaver | None = None,
):
@ -143,7 +143,7 @@ class LLMNode(BaseNode):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LLMNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -152,7 +152,7 @@ class LLMNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -166,7 +166,7 @@ class LLMNode(BaseNode):
return "1"
def _run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
node_inputs: Optional[dict[str, Any]] = None
node_inputs: dict[str, Any] | None = None
process_data = None
result_text = ""
usage = LLMUsage.empty_usage()
@ -353,10 +353,10 @@ class LLMNode(BaseNode):
node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
stop: Optional[Sequence[str]] = None,
stop: Sequence[str] | None = None,
user_id: str,
structured_output_enabled: bool,
structured_output: Optional[Mapping[str, Any]] = None,
structured_output: Mapping[str, Any] | None = None,
file_saver: LLMFileSaver,
file_outputs: list["File"],
node_id: str,
@ -708,7 +708,7 @@ class LLMNode(BaseNode):
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
tenant_id: str,
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
if isinstance(prompt_template, list):
@ -951,7 +951,7 @@ class LLMNode(BaseNode):
return variable_mapping
@classmethod
def get_default_config(cls, filters: Optional[dict] = None):
def get_default_config(cls, filters: dict | None = None):
return {
"type": "llm",
"config": {
@ -979,7 +979,7 @@ class LLMNode(BaseNode):
def handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
context: str | None,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
@ -1174,7 +1174,7 @@ class LLMNode(BaseNode):
def _combine_message_content_with_role(
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
):
match role:
case PromptMessageRole.USER:
@ -1280,7 +1280,7 @@ def _handle_memory_completion_mode(
def _handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,
context: Optional[str],
context: str | None,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
) -> Sequence[PromptMessage]:

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Annotated, Any, Literal, Optional
from typing import Annotated, Any, Literal
from pydantic import AfterValidator, BaseModel, Field
@ -35,7 +35,7 @@ class LoopVariableData(BaseModel):
label: str
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
value_type: Literal["variable", "constant"]
value: Optional[Any | list[str]] = None
value: Any | list[str] | None = None
class LoopNodeData(BaseLoopNodeData):
@ -46,8 +46,8 @@ class LoopNodeData(BaseLoopNodeData):
loop_count: int # Maximum number of loops
break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"]
loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list[LoopVariableData])
outputs: Optional[Mapping[str, Any]] = None
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
outputs: Mapping[str, Any] | None = None
class LoopStartNodeData(BaseNodeData):
@ -72,7 +72,7 @@ class LoopState(BaseLoopState):
"""
outputs: list[Any] = Field(default_factory=list)
current_output: Optional[Any] = None
current_output: Any | None = None
class MetaData(BaseLoopState.MetaData):
"""
@ -81,7 +81,7 @@ class LoopState(BaseLoopState):
loop_length: int
def get_last_output(self) -> Optional[Any]:
def get_last_output(self) -> Any | None:
"""
Get last output.
"""
@ -89,7 +89,7 @@ class LoopState(BaseLoopState):
return self.outputs[-1]
return None
def get_current_output(self) -> Optional[Any]:
def get_current_output(self) -> Any | None:
"""
Get current output.
"""

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -21,7 +21,7 @@ class LoopEndNode(BaseNode):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopEndNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -30,7 +30,7 @@ class LoopEndNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -3,7 +3,7 @@ import logging
import time
from collections.abc import Generator, Mapping, Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from typing import TYPE_CHECKING, Any, Literal, cast
from configs import dify_config
from core.variables import (
@ -57,7 +57,7 @@ class LoopNode(BaseNode):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -66,7 +66,7 @@ class LoopNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -21,7 +21,7 @@ class LoopStartNode(BaseNode):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopStartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -30,7 +30,7 @@ class LoopStartNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,4 +1,4 @@
from typing import Annotated, Any, Literal, Optional
from typing import Annotated, Any, Literal
from pydantic import (
BaseModel,
@ -50,7 +50,7 @@ class ParameterConfig(BaseModel):
name: str
type: Annotated[SegmentType, BeforeValidator(_validate_type)]
options: Optional[list[str]] = None
options: list[str] | None = None
description: str
required: bool
@ -88,8 +88,8 @@ class ParameterExtractorNodeData(BaseNodeData):
model: ModelConfig
query: list[str]
parameters: list[ParameterConfig]
instruction: Optional[str] = None
memory: Optional[MemoryConfig] = None
instruction: str | None = None
memory: MemoryConfig | None = None
reasoning_mode: Literal["function_call", "prompt"]
vision: VisionConfig = Field(default_factory=VisionConfig)

View File

@ -3,7 +3,7 @@ import json
import logging
import uuid
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import File
@ -98,7 +98,7 @@ class ParameterExtractorNode(BaseNode):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ParameterExtractorNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -107,7 +107,7 @@ class ParameterExtractorNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -116,11 +116,11 @@ class ParameterExtractorNode(BaseNode):
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
_model_instance: Optional[ModelInstance] = None
_model_config: Optional[ModelConfigWithCredentialsEntity] = None
_model_instance: ModelInstance | None = None
_model_config: ModelConfigWithCredentialsEntity | None = None
@classmethod
def get_default_config(cls, filters: Optional[dict] = None):
def get_default_config(cls, filters: dict | None = None):
return {
"model": {
"prompt_templates": {
@ -295,7 +295,7 @@ class ParameterExtractorNode(BaseNode):
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
stop: list[str],
) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]:
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=node_data_model.completion_params,
@ -330,9 +330,9 @@ class ParameterExtractorNode(BaseNode):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
"""
Generate function call prompt.
@ -412,9 +412,9 @@ class ParameterExtractorNode(BaseNode):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
"""
Generate prompt engineering prompt.
@ -450,9 +450,9 @@ class ParameterExtractorNode(BaseNode):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
"""
Generate completion prompt.
@ -484,9 +484,9 @@ class ParameterExtractorNode(BaseNode):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
"""
Generate chat prompt.
@ -657,7 +657,7 @@ class ParameterExtractorNode(BaseNode):
return transformed_result
def _extract_complete_json_response(self, result: str) -> Optional[dict]:
def _extract_complete_json_response(self, result: str) -> dict | None:
"""
Extract complete json response.
"""
@ -672,7 +672,7 @@ class ParameterExtractorNode(BaseNode):
logger.info("extra error: %s", result)
return None
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None:
"""
Extract json from tool call.
"""
@ -711,7 +711,7 @@ class ParameterExtractorNode(BaseNode):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
model_mode = ModelMode(node_data.model.mode)
@ -738,7 +738,7 @@ class ParameterExtractorNode(BaseNode):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
@ -774,7 +774,7 @@ class ParameterExtractorNode(BaseNode):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
context: Optional[str],
context: str | None,
) -> int:
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)

View File

@ -1,5 +1,3 @@
from typing import Optional
from pydantic import BaseModel, Field
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
@ -16,8 +14,8 @@ class QuestionClassifierNodeData(BaseNodeData):
query_variable_selector: list[str]
model: ModelConfig
classes: list[ClassConfig]
instruction: Optional[str] = None
memory: Optional[MemoryConfig] = None
instruction: str | None = None
memory: MemoryConfig | None = None
vision: VisionConfig = Field(default_factory=VisionConfig)
@property

View File

@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@ -59,8 +59,8 @@ class QuestionClassifierNode(BaseNode):
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
previous_node_id: str | None = None,
thread_pool_id: str | None = None,
*,
llm_file_saver: LLMFileSaver | None = None,
):
@ -86,7 +86,7 @@ class QuestionClassifierNode(BaseNode):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = QuestionClassifierNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -95,7 +95,7 @@ class QuestionClassifierNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -275,7 +275,7 @@ class QuestionClassifierNode(BaseNode):
return variable_mapping
@classmethod
def get_default_config(cls, filters: Optional[dict] = None):
def get_default_config(cls, filters: dict | None = None):
"""
Get default config of node.
:param filters: filter by node config parameters.
@ -288,7 +288,7 @@ class QuestionClassifierNode(BaseNode):
node_data: QuestionClassifierNodeData,
query: str,
model_config: ModelConfigWithCredentialsEntity,
context: Optional[str],
context: str | None,
) -> int:
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
@ -331,7 +331,7 @@ class QuestionClassifierNode(BaseNode):
self,
node_data: QuestionClassifierNodeData,
query: str,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
@ -18,7 +18,7 @@ class StartNode(BaseNode):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = StartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -27,7 +27,7 @@ class StartNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,6 +1,6 @@
import os
from collections.abc import Mapping, Sequence
from typing import Any, Optional
from typing import Any
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.entities.node_entities import NodeRunResult
@ -21,7 +21,7 @@ class TemplateTransformNode(BaseNode):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = TemplateTransformNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -30,7 +30,7 @@ class TemplateTransformNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -40,7 +40,7 @@ class TemplateTransformNode(BaseNode):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None):
def get_default_config(cls, filters: dict | None = None):
"""
Get default config of node.
:param filters: filter by node config parameters.

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -439,7 +439,7 @@ class ToolNode(BaseNode):
return result
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -448,7 +448,7 @@ class ToolNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,5 +1,3 @@
from typing import Optional
from pydantic import BaseModel
from core.variables.types import SegmentType
@ -33,4 +31,4 @@ class VariableAssignerNodeData(BaseNodeData):
type: str = "variable-assigner"
output_type: str
variables: list[list[str]]
advanced_settings: Optional[AdvancedSettings] = None
advanced_settings: AdvancedSettings | None = None

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.variables.segments import Segment
from core.workflow.entities.node_entities import NodeRunResult
@ -18,7 +18,7 @@ class VariableAggregatorNode(BaseNode):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -27,7 +27,7 @@ class VariableAggregatorNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,5 +1,5 @@
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, TypeAlias
from typing import TYPE_CHECKING, Any, TypeAlias
from core.variables import SegmentType, Variable
from core.variables.segments import BooleanSegment
@ -33,7 +33,7 @@ class VariableAssignerNode(BaseNode):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -42,7 +42,7 @@ class VariableAssignerNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -58,8 +58,8 @@ class VariableAssignerNode(BaseNode):
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
previous_node_id: str | None = None,
thread_pool_id: str | None = None,
conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
):
super().__init__(

View File

@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any, Optional, cast
from typing import Any, cast
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
@ -61,7 +61,7 @@ class VariableAssignerNode(BaseNode):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -70,7 +70,7 @@ class VariableAssignerNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]: