Merge branch 'main' into feat/memory-orchestration-be

# Conflicts:
#	api/core/app/apps/advanced_chat/app_runner.py
#	api/core/prompt/entities/advanced_prompt_entities.py
#	api/core/variables/segments.py
This commit is contained in:
Stream
2025-09-15 14:14:56 +08:00
2025 changed files with 67244 additions and 18565 deletions

View File

@ -1,6 +1,6 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, cast
from packaging.version import Version
from pydantic import ValidationError
@ -66,10 +66,10 @@ class AgentNode(BaseNode):
_node_type = NodeType.AGENT
_node_data: AgentNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = AgentNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -78,7 +78,7 @@ class AgentNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -153,7 +153,7 @@ class AgentNode(BaseNode):
messages=message_stream,
tool_info={
"icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
"agent_strategy": self._node_data.agent_strategy_name,
},
parameters_for_log=parameters_for_log,
user_id=self.user_id,
@ -394,15 +394,14 @@ class AgentNode(BaseNode):
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}"
== cast(AgentNodeData, self._node_data).agent_strategy_provider_name
if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name
)
icon = current_plugin.declaration.icon
except StopIteration:
icon = None
return icon
def _fetch_memory(self, model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
# get conversation id
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID.value]

View File

@ -1,4 +1,4 @@
from enum import Enum, StrEnum
from enum import IntEnum, StrEnum, auto
from typing import Any, Literal, Union
from pydantic import BaseModel
@ -25,9 +25,9 @@ class AgentNodeData(BaseNodeData):
agent_parameters: dict[str, AgentInput]
class ParamsAutoGenerated(Enum):
CLOSE = 0
OPEN = 1
class ParamsAutoGenerated(IntEnum):
CLOSE = auto()
OPEN = auto()
class AgentOldVersionModelFeatures(StrEnum):
@ -38,8 +38,8 @@ class AgentOldVersionModelFeatures(StrEnum):
TOOL_CALL = "tool-call"
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
VISION = "vision"
VISION = auto()
STREAM_TOOL_CALL = "stream-tool-call"
DOCUMENT = "document"
VIDEO = "video"
AUDIO = "audio"
DOCUMENT = auto()
VIDEO = auto()
AUDIO = auto()

View File

@ -1,6 +1,3 @@
from typing import Optional
class AgentNodeError(Exception):
"""Base exception for all agent node errors."""
@ -12,7 +9,7 @@ class AgentNodeError(Exception):
class AgentStrategyError(AgentNodeError):
"""Exception raised when there's an error with the agent strategy."""
def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None):
def __init__(self, message: str, strategy_name: str | None = None, provider_name: str | None = None):
self.strategy_name = strategy_name
self.provider_name = provider_name
super().__init__(message)
@ -21,7 +18,7 @@ class AgentStrategyError(AgentNodeError):
class AgentStrategyNotFoundError(AgentStrategyError):
"""Exception raised when the specified agent strategy is not found."""
def __init__(self, strategy_name: str, provider_name: Optional[str] = None):
def __init__(self, strategy_name: str, provider_name: str | None = None):
super().__init__(
f"Agent strategy '{strategy_name}' not found"
+ (f" for provider '{provider_name}'" if provider_name else ""),
@ -33,7 +30,7 @@ class AgentStrategyNotFoundError(AgentStrategyError):
class AgentInvocationError(AgentNodeError):
"""Exception raised when there's an error invoking the agent."""
def __init__(self, message: str, original_error: Optional[Exception] = None):
def __init__(self, message: str, original_error: Exception | None = None):
self.original_error = original_error
super().__init__(message)
@ -41,7 +38,7 @@ class AgentInvocationError(AgentNodeError):
class AgentParameterError(AgentNodeError):
"""Exception raised when there's an error with agent parameters."""
def __init__(self, message: str, parameter_name: Optional[str] = None):
def __init__(self, message: str, parameter_name: str | None = None):
self.parameter_name = parameter_name
super().__init__(message)
@ -49,7 +46,7 @@ class AgentParameterError(AgentNodeError):
class AgentVariableError(AgentNodeError):
"""Exception raised when there's an error with variables in the agent node."""
def __init__(self, message: str, variable_name: Optional[str] = None):
def __init__(self, message: str, variable_name: str | None = None):
self.variable_name = variable_name
super().__init__(message)
@ -71,7 +68,7 @@ class AgentInputTypeError(AgentNodeError):
class ToolFileError(AgentNodeError):
"""Exception raised when there's an error with a tool file."""
def __init__(self, message: str, file_id: Optional[str] = None):
def __init__(self, message: str, file_id: str | None = None):
self.file_id = file_id
super().__init__(message)
@ -86,7 +83,7 @@ class ToolFileNotFoundError(ToolFileError):
class AgentMessageTransformError(AgentNodeError):
"""Exception raised when there's an error transforming agent messages."""
def __init__(self, message: str, original_error: Optional[Exception] = None):
def __init__(self, message: str, original_error: Exception | None = None):
self.original_error = original_error
super().__init__(message)
@ -94,7 +91,7 @@ class AgentMessageTransformError(AgentNodeError):
class AgentModelError(AgentNodeError):
"""Exception raised when there's an error with the model used by the agent."""
def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None):
def __init__(self, message: str, model_name: str | None = None, provider: str | None = None):
self.model_name = model_name
self.provider = provider
super().__init__(message)
@ -103,7 +100,7 @@ class AgentModelError(AgentNodeError):
class AgentMemoryError(AgentNodeError):
"""Exception raised when there's an error with the agent's memory."""
def __init__(self, message: str, conversation_id: Optional[str] = None):
def __init__(self, message: str, conversation_id: str | None = None):
self.conversation_id = conversation_id
super().__init__(message)
@ -114,9 +111,9 @@ class AgentVariableTypeError(AgentNodeError):
def __init__(
self,
message: str,
variable_name: Optional[str] = None,
expected_type: Optional[str] = None,
actual_type: Optional[str] = None,
variable_name: str | None = None,
expected_type: str | None = None,
actual_type: str | None = None,
):
self.variable_name = variable_name
self.expected_type = expected_type

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, cast
from core.variables import ArrayFileSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
@ -22,10 +22,10 @@ class AnswerNode(BaseNode):
_node_data: AnswerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = AnswerNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -34,7 +34,7 @@ class AnswerNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -134,7 +134,7 @@ class AnswerStreamGeneratorRouter:
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
answer_dependencies: dict[str, list[str]],
) -> None:
):
"""
Recursive fetch answer dependencies
:param current_node_id: current node id

View File

@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
class AnswerStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
def __init__(self, graph: Graph, variable_pool: VariablePool):
super().__init__(graph, variable_pool)
self.generate_routes = graph.answer_stream_generate_routes
self.route_position = {}
@ -52,12 +52,12 @@ class AnswerStreamProcessor(StreamProcessor):
yield event
elif isinstance(event, NodeRunSucceededEvent | NodeRunExceptionEvent):
yield event
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: # ty: ignore [unresolved-attribute]
# update self.route_position after all stream event finished
for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]: # ty: ignore [unresolved-attribute]
self.route_position[answer_node_id] += 1
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] # ty: ignore [unresolved-attribute]
self._remove_unreachable_nodes(event)
@ -66,9 +66,9 @@ class AnswerStreamProcessor(StreamProcessor):
else:
yield event
def reset(self) -> None:
def reset(self):
self.route_position = {}
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
for answer_node_id, _ in self.generate_routes.answer_generate_route.items():
self.route_position[answer_node_id] = 0
self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {}
@ -149,9 +149,6 @@ class AnswerStreamProcessor(StreamProcessor):
return []
stream_output_value_selector = event.from_variable_selector
if not stream_output_value_selector:
return []
stream_out_answer_node_ids = []
for answer_node_id, route_position in self.route_position.items():
if answer_node_id not in self.rest_node_ids:

View File

@ -1,7 +1,6 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import Optional
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
@ -11,7 +10,7 @@ logger = logging.getLogger(__name__)
class StreamProcessor(ABC):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
def __init__(self, graph: Graph, variable_pool: VariablePool):
self.graph = graph
self.variable_pool = variable_pool
self.rest_node_ids = graph.node_ids.copy()
@ -20,7 +19,7 @@ class StreamProcessor(ABC):
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
raise NotImplementedError
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None:
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent):
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return
@ -72,7 +71,7 @@ class StreamProcessor(ABC):
for node_id in unreachable_first_node_ids:
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]:
def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: str | None = None) -> list[str]:
if node_id not in self.rest_node_ids:
self.rest_node_ids.append(node_id)
node_ids = []
@ -89,7 +88,7 @@ class StreamProcessor(ABC):
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify))
return node_ids
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]):
"""
remove target node ids until merge
"""

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from enum import Enum
from enum import StrEnum, auto
from pydantic import BaseModel, Field
@ -19,9 +19,9 @@ class GenerateRouteChunk(BaseModel):
Generate Route Chunk.
"""
class ChunkType(Enum):
VAR = "var"
TEXT = "text"
class ChunkType(StrEnum):
VAR = auto()
TEXT = auto()
type: ChunkType = Field(..., description="generate route chunk type")

View File

@ -1,7 +1,7 @@
import json
from abc import ABC
from enum import StrEnum
from typing import Any, Optional, Union
from typing import Any, Union
from pydantic import BaseModel, model_validator
@ -23,12 +23,12 @@ NumberType = Union[int, float]
class DefaultValue(BaseModel):
value: Any
value: Any = None
type: DefaultValueType
key: str
@staticmethod
def _parse_json(value: str) -> Any:
def _parse_json(value: str):
"""Unified JSON parsing handler"""
try:
return json.loads(value)
@ -121,10 +121,10 @@ class RetryConfig(BaseModel):
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
desc: str | None = None
version: str = "1"
error_strategy: Optional[ErrorStrategy] = None
default_value: Optional[list[DefaultValue]] = None
error_strategy: ErrorStrategy | None = None
default_value: list[DefaultValue] | None = None
retry_config: RetryConfig = RetryConfig()
@property
@ -135,7 +135,7 @@ class BaseNodeData(ABC, BaseModel):
class BaseIterationNodeData(BaseNodeData):
start_node_id: Optional[str] = None
start_node_id: str | None = None
class BaseIterationState(BaseModel):
@ -150,7 +150,7 @@ class BaseIterationState(BaseModel):
class BaseLoopNodeData(BaseNodeData):
start_node_id: Optional[str] = None
start_node_id: str | None = None
class BaseLoopState(BaseModel):

View File

@ -1,7 +1,7 @@
import logging
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
from typing import TYPE_CHECKING, Any, ClassVar, Union
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -26,9 +26,9 @@ class BaseNode:
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
) -> None:
previous_node_id: str | None = None,
thread_pool_id: str | None = None,
):
self.id = id
self.tenant_id = graph_init_params.tenant_id
self.app_id = graph_init_params.app_id
@ -51,7 +51,7 @@ class BaseNode:
self.node_id = node_id
@abstractmethod
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
def init_node_data(self, data: Mapping[str, Any]): ...
@abstractmethod
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
@ -141,7 +141,7 @@ class BaseNode:
return {}
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: dict | None = None):
return {}
@property
@ -170,7 +170,7 @@ class BaseNode:
# to BaseNodeData properties in a type-safe way
@abstractmethod
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
"""Get the error strategy for this node."""
...
@ -185,7 +185,7 @@ class BaseNode:
...
@abstractmethod
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
"""Get the node description."""
...
@ -201,7 +201,7 @@ class BaseNode:
# Public interface properties that delegate to abstract methods
@property
def error_strategy(self) -> Optional[ErrorStrategy]:
def error_strategy(self) -> ErrorStrategy | None:
"""Get the error strategy for this node."""
return self._get_error_strategy()
@ -216,7 +216,7 @@ class BaseNode:
return self._get_title()
@property
def description(self) -> Optional[str]:
def description(self) -> str | None:
"""Get the node description."""
return self._get_description()

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping, Sequence
from decimal import Decimal
from typing import Any, Optional
from typing import Any
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
@ -8,6 +8,7 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.variables.segments import ArrayFileSegment
from core.variables.types import SegmentType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
@ -27,10 +28,10 @@ class CodeNode(BaseNode):
_node_data: CodeNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = CodeNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -39,7 +40,7 @@ class CodeNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -49,7 +50,7 @@ class CodeNode(BaseNode):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: dict | None = None):
"""
Get default config of node.
:param filters: filter by node config parameters.
@ -119,6 +120,14 @@ class CodeNode(BaseNode):
return value.replace("\x00", "")
def _check_boolean(self, value: bool | None, variable: str) -> bool | None:
if value is None:
return None
if not isinstance(value, bool):
raise OutputValidationError(f"Output variable `{variable}` must be a boolean")
return value
def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
"""
Check number
@ -152,7 +161,7 @@ class CodeNode(BaseNode):
def _transform_result(
self,
result: Mapping[str, Any],
output_schema: Optional[dict[str, CodeNodeData.Output]],
output_schema: dict[str, CodeNodeData.Output] | None,
prefix: str = "",
depth: int = 1,
):
@ -173,6 +182,8 @@ class CodeNode(BaseNode):
prefix=f"{prefix}.{output_name}" if prefix else output_name,
depth=depth + 1,
)
elif isinstance(output_value, bool):
self._check_boolean(output_value, variable=f"{prefix}.{output_name}" if prefix else output_name)
elif isinstance(output_value, int | float):
self._check_number(
value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name
@ -232,7 +243,7 @@ class CodeNode(BaseNode):
if output_name not in result:
raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.")
if output_config.type == "object":
if output_config.type == SegmentType.OBJECT:
# check if output is object
if not isinstance(result.get(output_name), dict):
if result[output_name] is None:
@ -249,18 +260,28 @@ class CodeNode(BaseNode):
prefix=f"{prefix}.{output_name}",
depth=depth + 1,
)
elif output_config.type == "number":
elif output_config.type == SegmentType.NUMBER:
# check if number available
transformed_result[output_name] = self._check_number(
value=result[output_name], variable=f"{prefix}{dot}{output_name}"
)
elif output_config.type == "string":
checked = self._check_number(value=result[output_name], variable=f"{prefix}{dot}{output_name}")
# If the output is a boolean and the output schema specifies a NUMBER type,
# convert the boolean value to an integer.
#
# This ensures compatibility with existing workflows that may use
# `True` and `False` as values for NUMBER type outputs.
transformed_result[output_name] = self._convert_boolean_to_int(checked)
elif output_config.type == SegmentType.STRING:
# check if string available
transformed_result[output_name] = self._check_string(
value=result[output_name],
variable=f"{prefix}{dot}{output_name}",
)
elif output_config.type == "array[number]":
elif output_config.type == SegmentType.BOOLEAN:
transformed_result[output_name] = self._check_boolean(
value=result[output_name],
variable=f"{prefix}{dot}{output_name}",
)
elif output_config.type == SegmentType.ARRAY_NUMBER:
# check if array of number available
if not isinstance(result[output_name], list):
if result[output_name] is None:
@ -278,10 +299,17 @@ class CodeNode(BaseNode):
)
transformed_result[output_name] = [
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
# If the element is a boolean and the output schema specifies a `array[number]` type,
# convert the boolean value to an integer.
#
# This ensures compatibility with existing workflows that may use
# `True` and `False` as values for NUMBER type outputs.
self._convert_boolean_to_int(
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]"),
)
for i, value in enumerate(result[output_name])
]
elif output_config.type == "array[string]":
elif output_config.type == SegmentType.ARRAY_STRING:
# check if array of string available
if not isinstance(result[output_name], list):
if result[output_name] is None:
@ -302,7 +330,7 @@ class CodeNode(BaseNode):
self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
for i, value in enumerate(result[output_name])
]
elif output_config.type == "array[object]":
elif output_config.type == SegmentType.ARRAY_OBJECT:
# check if array of object available
if not isinstance(result[output_name], list):
if result[output_name] is None:
@ -340,6 +368,22 @@ class CodeNode(BaseNode):
)
for i, value in enumerate(result[output_name])
]
elif output_config.type == SegmentType.ARRAY_BOOLEAN:
# check if array of object available
if not isinstance(result[output_name], list):
if result[output_name] is None:
transformed_result[output_name] = None
else:
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} is not an array,"
f" got {type(result.get(output_name))} instead."
)
else:
transformed_result[output_name] = [
self._check_boolean(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
for i, value in enumerate(result[output_name])
]
else:
raise OutputValidationError(f"Output type {output_config.type} is not supported.")
@ -374,3 +418,16 @@ class CodeNode(BaseNode):
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled
@staticmethod
def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None:
"""This function convert boolean to integers when the output schema specifies a NUMBER type.
This ensures compatibility with existing workflows that may use
`True` and `False` as values for NUMBER type outputs.
"""
if value is None:
return None
if isinstance(value, bool):
return int(value)
return value

View File

@ -1,11 +1,31 @@
from typing import Literal, Optional
from typing import Annotated, Literal
from pydantic import BaseModel
from pydantic import AfterValidator, BaseModel
from core.helper.code_executor.code_executor import CodeLanguage
from core.variables.types import SegmentType
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.nodes.base import BaseNodeData
_ALLOWED_OUTPUT_FROM_CODE = frozenset(
[
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.OBJECT,
SegmentType.BOOLEAN,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_BOOLEAN,
]
)
def _validate_type(segment_type: SegmentType) -> SegmentType:
if segment_type not in _ALLOWED_OUTPUT_FROM_CODE:
raise ValueError(f"invalid type for code output, expected {_ALLOWED_OUTPUT_FROM_CODE}, actual {segment_type}")
return segment_type
class CodeNodeData(BaseNodeData):
"""
@ -13,8 +33,8 @@ class CodeNodeData(BaseNodeData):
"""
class Output(BaseModel):
type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
children: Optional[dict[str, "CodeNodeData.Output"]] = None
type: Annotated[SegmentType, AfterValidator(_validate_type)]
children: dict[str, "CodeNodeData.Output"] | None = None
class Dependency(BaseModel):
name: str
@ -24,4 +44,4 @@ class CodeNodeData(BaseNodeData):
code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]
code: str
outputs: dict[str, Output]
dependencies: Optional[list[Dependency]] = None
dependencies: list[Dependency] | None = None

View File

@ -5,7 +5,7 @@ import logging
import os
import tempfile
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any
import chardet
import docx
@ -47,10 +47,10 @@ class DocumentExtractorNode(BaseNode):
_node_data: DocumentExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = DocumentExtractorNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -59,7 +59,7 @@ class DocumentExtractorNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -302,12 +302,12 @@ def _extract_text_from_yaml(file_content: bytes) -> str:
encoding = "utf-8"
yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore"))
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e:
# If decoding fails, try with utf-8 as last resort
try:
yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore"))
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
except (UnicodeDecodeError, yaml.YAMLError):
raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e
@ -428,9 +428,9 @@ def _download_file_content(file: File) -> bytes:
raise FileDownloadError("Missing URL for remote file")
response = ssrf_proxy.get(file.remote_url)
response.raise_for_status()
return cast(bytes, response.content)
return response.content
else:
return cast(bytes, file_manager.download(file))
return file_manager.download(file)
except Exception as e:
raise FileDownloadError(f"Error downloading file: {str(e)}") from e
@ -515,14 +515,14 @@ def _extract_text_from_excel(file_content: bytes) -> str:
df.dropna(how="all", inplace=True)
# Combine multi-line text in each cell into a single line
df = df.applymap(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x) # type: ignore
df = df.map(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x)
# Combine multi-line text in column names into a single line
df.columns = pd.Index([" ".join(str(col).splitlines()) for col in df.columns])
# Manually construct the Markdown table
markdown_table += _construct_markdown_table(df) + "\n\n"
except Exception as e:
except Exception:
continue
return markdown_table
except Exception as e:

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -14,10 +14,10 @@ class EndNode(BaseNode):
_node_data: EndNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = EndNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -26,7 +26,7 @@ class EndNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -121,7 +121,7 @@ class EndStreamGeneratorRouter:
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
end_dependencies: dict[str, list[str]],
) -> None:
):
"""
Recursive fetch end dependencies
:param current_node_id: current node id

View File

@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
class EndStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
def __init__(self, graph: Graph, variable_pool: VariablePool):
super().__init__(graph, variable_pool)
self.end_stream_param = graph.end_stream_param
self.route_position = {}
@ -76,7 +76,7 @@ class EndStreamProcessor(StreamProcessor):
else:
yield event
def reset(self) -> None:
def reset(self):
self.route_position = {}
for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
self.route_position[end_node_id] = 0

View File

@ -30,6 +30,7 @@ class ModelInvokeCompletedEvent(BaseModel):
text: str
usage: LLMUsage
finish_reason: str | None = None
reasoning_content: str | None = None
class RunRetryEvent(BaseModel):

View File

@ -1,7 +1,7 @@
import mimetypes
from collections.abc import Sequence
from email.message import Message
from typing import Any, Literal, Optional
from typing import Any, Literal
import httpx
from pydantic import BaseModel, Field, ValidationInfo, field_validator
@ -18,7 +18,7 @@ class HttpRequestNodeAuthorizationConfig(BaseModel):
class HttpRequestNodeAuthorization(BaseModel):
type: Literal["no-auth", "api-key"]
config: Optional[HttpRequestNodeAuthorizationConfig] = None
config: HttpRequestNodeAuthorizationConfig | None = None
@field_validator("config", mode="before")
@classmethod
@ -88,9 +88,9 @@ class HttpRequestNodeData(BaseNodeData):
authorization: HttpRequestNodeAuthorization
headers: str
params: str
body: Optional[HttpRequestNodeBody] = None
timeout: Optional[HttpRequestNodeTimeout] = None
ssl_verify: Optional[bool] = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
body: HttpRequestNodeBody | None = None
timeout: HttpRequestNodeTimeout | None = None
ssl_verify: bool | None = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
class Response:
@ -183,7 +183,7 @@ class Response:
return f"{(self.size / 1024 / 1024):.2f} MB"
@property
def parsed_content_disposition(self) -> Optional[Message]:
def parsed_content_disposition(self) -> Message | None:
content_disposition = self.headers.get("content-disposition", "")
if content_disposition:
msg = Message()

View File

@ -329,22 +329,16 @@ class Executor:
"""
do http request depending on api bundle
"""
if self.method not in {
"get",
"head",
"post",
"put",
"delete",
"patch",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
}:
_METHOD_MAP = {
"get": ssrf_proxy.get,
"head": ssrf_proxy.head,
"post": ssrf_proxy.post,
"put": ssrf_proxy.put,
"delete": ssrf_proxy.delete,
"patch": ssrf_proxy.patch,
}
method_lc = self.method.lower()
if method_lc not in _METHOD_MAP:
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
request_args = {
@ -362,11 +356,11 @@ class Executor:
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
response = getattr(ssrf_proxy, self.method.lower())(**request_args)
response: httpx.Response = _METHOD_MAP[method_lc](**request_args)
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
raise HttpRequestNodeError(str(e)) from e
# FIXME: fix type ignore, this maybe httpx type issue
return response # type: ignore
return response
def invoke(self) -> Response:
# assemble headers

View File

@ -1,7 +1,7 @@
import logging
import mimetypes
from collections.abc import Mapping, Sequence
from typing import Any, Optional
from typing import Any
from configs import dify_config
from core.file import File, FileTransferMethod
@ -38,10 +38,10 @@ class HttpRequestNode(BaseNode):
_node_data: HttpRequestNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = HttpRequestNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -50,7 +50,7 @@ class HttpRequestNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -60,7 +60,7 @@ class HttpRequestNode(BaseNode):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
def get_default_config(cls, filters: dict[str, Any] | None = None):
return {
"type": "http-request",
"config": {

View File

@ -1,4 +1,4 @@
from typing import Literal, Optional
from typing import Literal
from pydantic import BaseModel, Field
@ -20,7 +20,7 @@ class IfElseNodeData(BaseNodeData):
logical_operator: Literal["and", "or"]
conditions: list[Condition]
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
logical_operator: Literal["and", "or"] | None = "and"
conditions: list[Condition] | None = Field(default=None, deprecated=True)
cases: Optional[list[Case]] = None
cases: list[Case] | None = None

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal, Optional
from typing import Any, Literal
from typing_extensions import deprecated
@ -19,10 +19,10 @@ class IfElseNode(BaseNode):
_node_data: IfElseNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IfElseNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -31,7 +31,7 @@ class IfElseNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -83,7 +83,7 @@ class IfElseNode(BaseNode):
else:
# TODO: Update database then remove this
# Fallback to old structure if cases are not defined
input_conditions, group_result, final_result = _should_not_use_old_function(
input_conditions, group_result, final_result = _should_not_use_old_function( # ty: ignore [deprecated]
condition_processor=condition_processor,
variable_pool=self.graph_runtime_state.variable_pool,
conditions=self._node_data.conditions or [],

View File

@ -1,5 +1,5 @@
from enum import StrEnum
from typing import Any, Optional
from typing import Any
from pydantic import Field
@ -17,7 +17,7 @@ class IterationNodeData(BaseIterationNodeData):
Iteration Node Data.
"""
parent_loop_id: Optional[str] = None # redundant field, not used currently
parent_loop_id: str | None = None # redundant field, not used currently
iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector
is_parallel: bool = False # open the parallel mode or not
@ -39,7 +39,7 @@ class IterationState(BaseIterationState):
"""
outputs: list[Any] = Field(default_factory=list)
current_output: Optional[Any] = None
current_output: Any | None = None
class MetaData(BaseIterationState.MetaData):
"""
@ -48,7 +48,7 @@ class IterationState(BaseIterationState):
iterator_length: int
def get_last_output(self) -> Optional[Any]:
def get_last_output(self) -> Any | None:
"""
Get last output.
"""
@ -56,7 +56,7 @@ class IterationState(BaseIterationState):
return self.outputs[-1]
return None
def get_current_output(self) -> Optional[Any]:
def get_current_output(self) -> Any | None:
"""
Get current output.
"""

View File

@ -6,12 +6,12 @@ from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, wait
from datetime import datetime
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, cast
from flask import Flask, current_app
from configs import dify_config
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities.node_entities import (
NodeRunResult,
@ -67,10 +67,10 @@ class IterationNode(BaseNode):
_node_data: IterationNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IterationNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -79,7 +79,7 @@ class IterationNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -89,7 +89,7 @@ class IterationNode(BaseNode):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: dict | None = None):
return {
"type": "iteration",
"config": {
@ -112,10 +112,10 @@ class IterationNode(BaseNode):
if not variable:
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
if isinstance(variable, NoneVariable) or len(variable.value) == 0:
if isinstance(variable, NoneSegment) or len(variable.value) == 0:
# Try our best to preserve the type informat.
if isinstance(variable, ArraySegment):
output = variable.model_copy(update={"value": []})
@ -424,7 +424,7 @@ class IterationNode(BaseNode):
graph_engine: "GraphEngine",
iteration_graph: Graph,
iter_run_map: dict[str, float],
parallel_mode_run_id: Optional[str] = None,
parallel_mode_run_id: str | None = None,
) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
run single iteration
@ -441,8 +441,8 @@ class IterationNode(BaseNode):
iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}"
next_index = int(current_index) + 1
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: # ty: ignore [unresolved-attribute]
event.in_iteration_id = self.node_id # ty: ignore [unresolved-attribute]
if (
isinstance(event, BaseNodeEvent)

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -18,10 +18,10 @@ class IterationStartNode(BaseNode):
_node_data: IterationStartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IterationStartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -30,7 +30,7 @@ class IterationStartNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Literal, Optional
from typing import Literal
from pydantic import BaseModel, Field
@ -49,11 +49,11 @@ class MultipleRetrievalConfig(BaseModel):
"""
top_k: int
score_threshold: Optional[float] = None
score_threshold: float | None = None
reranking_mode: str = "reranking_model"
reranking_enable: bool = True
reranking_model: Optional[RerankingModelConfig] = None
weights: Optional[WeightedScoreConfig] = None
reranking_model: RerankingModelConfig | None = None
weights: WeightedScoreConfig | None = None
class SingleRetrievalConfig(BaseModel):
@ -104,8 +104,8 @@ class MetadataFilteringCondition(BaseModel):
Metadata Filtering Condition.
"""
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
logical_operator: Literal["and", "or"] | None = "and"
conditions: list[Condition] | None = Field(default=None, deprecated=True)
class KnowledgeRetrievalNodeData(BaseNodeData):
@ -117,11 +117,11 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
query_variable_selector: list[str]
dataset_ids: list[str]
retrieval_mode: Literal["single", "multiple"]
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
single_retrieval_config: Optional[SingleRetrievalConfig] = None
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
metadata_model_config: Optional[ModelConfig] = None
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
multiple_retrieval_config: MultipleRetrievalConfig | None = None
single_retrieval_config: SingleRetrievalConfig | None = None
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
metadata_model_config: ModelConfig | None = None
metadata_filtering_conditions: MetadataFilteringCondition | None = None
vision: VisionConfig = Field(default_factory=VisionConfig)
@property

View File

@ -4,9 +4,9 @@ import re
import time
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import Float, and_, func, or_, text
from sqlalchemy import Float, and_, func, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy.orm import sessionmaker
@ -78,7 +78,7 @@ default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"top_k": 4,
"score_threshold_enabled": False,
}
@ -101,11 +101,11 @@ class KnowledgeRetrievalNode(BaseNode):
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
previous_node_id: str | None = None,
thread_pool_id: str | None = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
):
super().__init__(
id=id,
config=config,
@ -125,10 +125,10 @@ class KnowledgeRetrievalNode(BaseNode):
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = KnowledgeRetrievalNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -137,7 +137,7 @@ class KnowledgeRetrievalNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -259,7 +259,7 @@ class KnowledgeRetrievalNode(BaseNode):
)
all_documents = []
dataset_retrieval = DatasetRetrieval()
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
# fetch model config
if node_data.single_retrieval_config is None:
raise ValueError("single_retrieval_config is required")
@ -291,7 +291,7 @@ class KnowledgeRetrievalNode(BaseNode):
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
)
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required")
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
@ -367,15 +367,12 @@ class KnowledgeRetrievalNode(BaseNode):
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
document = (
db.session.query(Document)
.where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
.first()
stmt = select(Document).where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
document = db.session.scalar(stmt)
if dataset and document:
source = {
"metadata": {
@ -422,7 +419,7 @@ class KnowledgeRetrievalNode(BaseNode):
def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
document_query = db.session.query(Document).where(
Document.dataset_id.in_(dataset_ids),
Document.indexing_status == "completed",
@ -514,7 +511,8 @@ class KnowledgeRetrievalNode(BaseNode):
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> list[dict[str, Any]]:
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
metadata_fields = db.session.scalars(stmt).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
if node_data.metadata_model_config is None:
raise ValueError("metadata_model_config is required")
@ -573,12 +571,12 @@ class KnowledgeRetrievalNode(BaseNode):
"condition": item.get("comparison_operator"),
}
)
except Exception as e:
except Exception:
return []
return automatic_metadata_filters
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
):
if value is None and condition not in ("empty", "not empty"):
return

View File

@ -1,36 +1,43 @@
from collections.abc import Sequence
from typing import Literal
from enum import StrEnum
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
_Condition = Literal[
class FilterOperator(StrEnum):
# string conditions
"contains",
"start with",
"end with",
"is",
"in",
"empty",
"not contains",
"is not",
"not in",
"not empty",
CONTAINS = "contains"
START_WITH = "start with"
END_WITH = "end with"
IS = "is"
IN = "in"
EMPTY = "empty"
NOT_CONTAINS = "not contains"
IS_NOT = "is not"
NOT_IN = "not in"
NOT_EMPTY = "not empty"
# number conditions
"=",
"",
"<",
">",
"",
"",
]
EQUAL = "="
NOT_EQUAL = ""
LESS_THAN = "<"
GREATER_THAN = ">"
GREATER_THAN_OR_EQUAL = ""
LESS_THAN_OR_EQUAL = ""
class Order(StrEnum):
ASC = "asc"
DESC = "desc"
class FilterCondition(BaseModel):
key: str = ""
comparison_operator: _Condition = "contains"
value: str | Sequence[str] = ""
comparison_operator: FilterOperator = FilterOperator.CONTAINS
# the value is bool if the filter operator is comparing with
# a boolean constant.
value: str | Sequence[str] | bool = ""
class FilterBy(BaseModel):
@ -38,10 +45,10 @@ class FilterBy(BaseModel):
conditions: Sequence[FilterCondition] = Field(default_factory=list)
class OrderBy(BaseModel):
class OrderByConfig(BaseModel):
enabled: bool = False
key: str = ""
value: Literal["asc", "desc"] = "asc"
value: Order = Order.ASC
class Limit(BaseModel):
@ -57,6 +64,6 @@ class ExtractConfig(BaseModel):
class ListOperatorNodeData(BaseNodeData):
variable: Sequence[str] = Field(default_factory=list)
filter_by: FilterBy
order_by: OrderBy
order_by: OrderByConfig
limit: Limit
extract_by: ExtractConfig = Field(default_factory=ExtractConfig)

View File

@ -1,28 +1,50 @@
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Literal, Optional, Union
from typing import Any, TypeAlias, TypeVar
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from .entities import ListOperatorNodeData
from .entities import FilterOperator, ListOperatorNodeData, Order
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
_SUPPORTED_TYPES_TUPLE = (
ArrayFileSegment,
ArrayNumberSegment,
ArrayStringSegment,
ArrayBooleanSegment,
)
_SUPPORTED_TYPES_ALIAS: TypeAlias = ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment | ArrayBooleanSegment
_T = TypeVar("_T")
def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
"""Returns the negation of a given filter function. If the original filter
returns `True` for a value, the negated filter will return `False`, and vice versa.
"""
def wrapper(value: _T) -> bool:
return not filter_(value)
return wrapper
class ListOperatorNode(BaseNode):
_node_type = NodeType.LIST_OPERATOR
_node_data: ListOperatorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ListOperatorNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -31,7 +53,7 @@ class ListOperatorNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -45,8 +67,8 @@ class ListOperatorNode(BaseNode):
return "1"
def _run(self):
inputs: dict[str, list] = {}
process_data: dict[str, list] = {}
inputs: dict[str, Sequence[object]] = {}
process_data: dict[str, Sequence[object]] = {}
outputs: dict[str, Any] = {}
variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)
@ -69,11 +91,8 @@ class ListOperatorNode(BaseNode):
process_data=process_data,
outputs=outputs,
)
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
error_message = (
f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
"or ArrayStringSegment"
)
if not isinstance(variable, _SUPPORTED_TYPES_TUPLE):
error_message = f"Variable {self._node_data.variable} is not an array type, actual type: {type(variable)}"
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
)
@ -122,9 +141,7 @@ class ListOperatorNode(BaseNode):
outputs=outputs,
)
def _apply_filter(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
filter_func: Callable[[Any], bool]
result: list[Any] = []
for condition in self._node_data.filter_by.conditions:
@ -154,33 +171,35 @@ class ListOperatorNode(BaseNode):
)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayBooleanSegment):
if not isinstance(condition.value, bool):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
else:
raise AssertionError("this statment should be unreachable.")
return variable
def _apply_order(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
if isinstance(variable, ArrayStringSegment):
result = _order_string(order=self._node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
result = _order_number(order=self._node_data.order_by.value, array=variable.value)
def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
result = sorted(variable.value, reverse=self._node_data.order_by == Order.DESC)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
result = _order_file(
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
)
variable = variable.model_copy(update={"value": result})
else:
raise AssertionError("this statement should be unreachable")
return variable
def _apply_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
result = variable.value[: self._node_data.limit.size]
return variable.model_copy(update={"value": result})
def _extract_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
if value < 1:
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
@ -232,11 +251,11 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo
case "empty":
return lambda x: x == ""
case "not contains":
return lambda x: not _contains(value)(x)
return _negation(_contains(value))
case "is not":
return lambda x: not _is(value)(x)
return _negation(_is(value))
case "not in":
return lambda x: not _in(value)(x)
return _negation(_in(value))
case "not empty":
return lambda x: x != ""
case _:
@ -248,7 +267,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab
case "in":
return _in(value)
case "not in":
return lambda x: not _in(value)(x)
return _negation(_in(value))
case _:
raise InvalidConditionError(f"Invalid condition: {condition}")
@ -271,6 +290,16 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Callable[[bool], bool]:
match condition:
case FilterOperator.IS:
return _is(value)
case FilterOperator.IS_NOT:
return _negation(_is(value))
case _:
raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
extract_func: Callable[[File], Any]
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
@ -298,7 +327,7 @@ def _endswith(value: str) -> Callable[[str], bool]:
return lambda x: x.endswith(value)
def _is(value: str) -> Callable[[str], bool]:
def _is(value: _T) -> Callable[[_T], bool]:
return lambda x: x == value
@ -330,21 +359,13 @@ def _ge(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x >= value
def _order_number(*, order: Literal["asc", "desc"], array: Sequence[int | float]):
return sorted(array, key=lambda x: x, reverse=order == "desc")
def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]):
return sorted(array, key=lambda x: x, reverse=order == "desc")
def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]):
def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]):
extract_func: Callable[[File], Any]
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}:
extract_func = _get_file_extract_string_func(key=order_by)
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
elif order_by == "size":
extract_func = _get_file_extract_number_func(key=order_by)
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
else:
raise InvalidKeyError(f"Invalid order key: {order_by}")

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Optional
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
@ -18,7 +18,7 @@ class ModelConfig(BaseModel):
class ContextConfig(BaseModel):
enabled: bool
variable_selector: Optional[list[str]] = None
variable_selector: list[str] | None = None
class VisionConfigOptions(BaseModel):
@ -51,23 +51,40 @@ class PromptConfig(BaseModel):
class LLMNodeChatModelMessage(ChatModelMessage):
text: str = ""
jinja2_text: Optional[str] = None
jinja2_text: str | None = None
class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
jinja2_text: Optional[str] = None
jinja2_text: str | None = None
class LLMNodeData(BaseNodeData):
model: ModelConfig
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
memory: Optional[MemoryConfig] = None
memory: MemoryConfig | None = None
context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: Mapping[str, Any] | None = None
# We used 'structured_output_enabled' in the past, but it's not a good name.
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
reasoning_format: Literal["separated", "tagged"] = Field(
# Keep tagged as default for backward compatibility
default="tagged",
description=(
"""
Strategy for handling model reasoning output.
separated: Return clean text (without <think> tags) + reasoning_content field.
Recommended for new workflows. Enables safe downstream parsing and
workflow variable access: {{#node_id.reasoning_content#}}
tagged : Return original text (with <think> tags) + reasoning_content field.
Maintains full backward compatibility while still providing reasoning_content
for workflow automation. Frontend thinking panels work as before.
"""
),
)
@field_validator("prompt_config", mode="before")
@classmethod

View File

@ -41,5 +41,5 @@ class FileTypeNotSupportError(LLMNodeError):
class UnsupportedPromptContentTypeError(LLMNodeError):
def __init__(self, *, type_name: str) -> None:
def __init__(self, *, type_name: str):
super().__init__(f"Prompt content type {type_name} is not supported.")

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Optional, cast
from typing import cast
from sqlalchemy import select, update
from sqlalchemy.orm import Session
@ -86,8 +86,8 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
def fetch_memory(
variable_pool: VariablePool, app_id: str, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
) -> Optional[TokenBufferMemory]:
variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
) -> TokenBufferMemory | None:
if not node_data_memory:
return None
@ -107,7 +107,7 @@ def fetch_memory(
return memory
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration

View File

@ -2,8 +2,9 @@ import base64
import io
import json
import logging
import re
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Literal, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -59,7 +60,6 @@ from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
@ -96,6 +96,7 @@ from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.graph_engine.entities.event import InNodeEvent
logger = logging.getLogger(__name__)
@ -105,6 +106,9 @@ class LLMNode(BaseNode):
_node_data: LLMNodeData
# Compiled regex for extracting <think> blocks (with compatibility for attributes)
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
# Instance attributes specific to LLMNode.
# Output variable for file
_file_outputs: list["File"]
@ -118,11 +122,11 @@ class LLMNode(BaseNode):
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
previous_node_id: str | None = None,
thread_pool_id: str | None = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
):
super().__init__(
id=id,
config=config,
@ -142,10 +146,10 @@ class LLMNode(BaseNode):
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LLMNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -154,7 +158,7 @@ class LLMNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -167,12 +171,13 @@ class LLMNode(BaseNode):
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
node_inputs: Optional[dict[str, Any]] = None
def _run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
node_inputs: dict[str, Any] | None = None
process_data = None
result_text = ""
usage = LLMUsage.empty_usage()
finish_reason = None
reasoning_content = None
variable_pool = self.graph_runtime_state.variable_pool
try:
@ -262,6 +267,7 @@ class LLMNode(BaseNode):
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
reasoning_format=self._node_data.reasoning_format,
)
structured_output: LLMStructuredOutput | None = None
@ -270,9 +276,20 @@ class LLMNode(BaseNode):
if isinstance(event, RunStreamChunkEvent):
yield event
elif isinstance(event, ModelInvokeCompletedEvent):
# Raw text
result_text = event.text
usage = event.usage
finish_reason = event.finish_reason
reasoning_content = event.reasoning_content or ""
# For downstream nodes, determine clean text based on reasoning_format
if self._node_data.reasoning_format == "tagged":
# Keep <think> tags for backward compatibility
clean_text = result_text
else:
# Extract clean text from <think> tags
clean_text, _ = LLMNode._split_reasoning(result_text, self._node_data.reasoning_format)
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
break
@ -290,7 +307,12 @@ class LLMNode(BaseNode):
"model_name": model_config.model,
}
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
outputs = {
"text": clean_text,
"reasoning_content": reasoning_content,
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
}
if structured_output:
outputs["structured_output"] = structured_output.structured_output
if self._file_outputs is not None:
@ -342,13 +364,14 @@ class LLMNode(BaseNode):
node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
stop: Optional[Sequence[str]] = None,
stop: Sequence[str] | None = None,
user_id: str,
structured_output_enabled: bool,
structured_output: Optional[Mapping[str, Any]] = None,
structured_output: Mapping[str, Any] | None = None,
file_saver: LLMFileSaver,
file_outputs: list["File"],
node_id: str,
reasoning_format: Literal["separated", "tagged"] = "tagged",
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
model_schema = model_instance.model_type_instance.get_model_schema(
node_data_model.name, model_instance.credentials
@ -385,6 +408,7 @@ class LLMNode(BaseNode):
file_saver=file_saver,
file_outputs=file_outputs,
node_id=node_id,
reasoning_format=reasoning_format,
)
@staticmethod
@ -394,6 +418,7 @@ class LLMNode(BaseNode):
file_saver: LLMFileSaver,
file_outputs: list["File"],
node_id: str,
reasoning_format: Literal["separated", "tagged"] = "tagged",
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
# For blocking mode
if isinstance(invoke_result, LLMResult):
@ -401,6 +426,7 @@ class LLMNode(BaseNode):
invoke_result=invoke_result,
saver=file_saver,
file_outputs=file_outputs,
reasoning_format=reasoning_format,
)
yield event
return
@ -441,13 +467,66 @@ class LLMNode(BaseNode):
except OutputParserError as e:
raise LLMNodeError(f"Failed to parse structured output: {e}")
yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
# Extract reasoning content from <think> tags in the main text
full_text = full_text_buffer.getvalue()
if reasoning_format == "tagged":
# Keep <think> tags in text for backward compatibility
clean_text = full_text
reasoning_content = ""
else:
# Extract clean text and reasoning from <think> tags
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
yield ModelInvokeCompletedEvent(
# Use clean_text for separated mode, full_text for tagged mode
text=clean_text if reasoning_format == "separated" else full_text,
usage=usage,
finish_reason=finish_reason,
# Reasoning content for workflow variables and downstream nodes
reasoning_content=reasoning_content,
)
@staticmethod
def _image_file_to_markdown(file: "File", /):
text_chunk = f"![]({file.generate_url()})"
return text_chunk
@classmethod
def _split_reasoning(
cls, text: str, reasoning_format: Literal["separated", "tagged"] = "tagged"
) -> tuple[str, str]:
"""
Split reasoning content from text based on reasoning_format strategy.
Args:
text: Full text that may contain <think> blocks
reasoning_format: Strategy for handling reasoning content
- "separated": Remove <think> tags and return clean text + reasoning_content field
- "tagged": Keep <think> tags in text, return empty reasoning_content
Returns:
tuple of (clean_text, reasoning_content)
"""
if reasoning_format == "tagged":
return text, ""
# Find all <think>...</think> blocks (case-insensitive)
matches = cls._THINK_PATTERN.findall(text)
# Extract reasoning content from all <think> blocks
reasoning_content = "\n".join(match.strip() for match in matches) if matches else ""
# Remove all <think>...</think> blocks from original text
clean_text = cls._THINK_PATTERN.sub("", text)
# Clean up extra whitespace
clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
# Separated mode: always return clean text and reasoning_content
return clean_text, reasoning_content or ""
def _transform_chat_messages(
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
@ -640,7 +719,7 @@ class LLMNode(BaseNode):
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
tenant_id: str,
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
if isinstance(prompt_template, list):
@ -748,7 +827,7 @@ class LLMNode(BaseNode):
and isinstance(prompt_messages[-1], UserPromptMessage)
and isinstance(prompt_messages[-1].content, list)
):
prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts)
prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
@ -883,7 +962,7 @@ class LLMNode(BaseNode):
return variable_mapping
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: dict | None = None):
return {
"type": "llm",
"config": {
@ -911,7 +990,7 @@ class LLMNode(BaseNode):
def handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
context: str | None,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
@ -975,6 +1054,7 @@ class LLMNode(BaseNode):
invoke_result: LLMResult,
saver: LLMFileSaver,
file_outputs: list["File"],
reasoning_format: Literal["separated", "tagged"] = "tagged",
) -> ModelInvokeCompletedEvent:
buffer = io.StringIO()
for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
@ -984,10 +1064,24 @@ class LLMNode(BaseNode):
):
buffer.write(text_part)
# Extract reasoning content from <think> tags in the main text
full_text = buffer.getvalue()
if reasoning_format == "tagged":
# Keep <think> tags in text for backward compatibility
clean_text = full_text
reasoning_content = ""
else:
# Extract clean text and reasoning from <think> tags
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
return ModelInvokeCompletedEvent(
text=buffer.getvalue(),
# Use clean_text for separated mode, full_text for tagged mode
text=clean_text if reasoning_format == "separated" else full_text,
usage=invoke_result.usage,
finish_reason=None,
# Reasoning content for workflow variables and downstream nodes
reasoning_content=reasoning_content,
)
@staticmethod
@ -1156,7 +1250,7 @@ class LLMNode(BaseNode):
def _combine_message_content_with_role(
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
):
match role:
case PromptMessageRole.USER:
@ -1165,7 +1259,8 @@ def _combine_message_content_with_role(
return AssistantPromptMessage(content=contents)
case PromptMessageRole.SYSTEM:
return SystemPromptMessage(content=contents)
raise NotImplementedError(f"Role {role} is not supported")
case _:
raise NotImplementedError(f"Role {role} is not supported")
def _render_jinja2_message(
@ -1261,7 +1356,7 @@ def _handle_memory_completion_mode(
def _handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,
context: Optional[str],
context: str | None,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
) -> Sequence[PromptMessage]:

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Annotated, Any, Literal, Optional
from typing import Annotated, Any, Literal
from pydantic import AfterValidator, BaseModel, Field
@ -12,9 +12,11 @@ _VALID_VAR_TYPE = frozenset(
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.OBJECT,
SegmentType.BOOLEAN,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_BOOLEAN,
]
)
@ -33,7 +35,7 @@ class LoopVariableData(BaseModel):
label: str
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
value_type: Literal["variable", "constant"]
value: Optional[Any | list[str]] = None
value: Any | list[str] | None = None
class LoopNodeData(BaseLoopNodeData):
@ -44,8 +46,8 @@ class LoopNodeData(BaseLoopNodeData):
loop_count: int # Maximum number of loops
break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"]
loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list[LoopVariableData])
outputs: Optional[Mapping[str, Any]] = None
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
outputs: Mapping[str, Any] | None = None
class LoopStartNodeData(BaseNodeData):
@ -70,7 +72,7 @@ class LoopState(BaseLoopState):
"""
outputs: list[Any] = Field(default_factory=list)
current_output: Optional[Any] = None
current_output: Any | None = None
class MetaData(BaseLoopState.MetaData):
"""
@ -79,7 +81,7 @@ class LoopState(BaseLoopState):
loop_length: int
def get_last_output(self) -> Optional[Any]:
def get_last_output(self) -> Any | None:
"""
Get last output.
"""
@ -87,7 +89,7 @@ class LoopState(BaseLoopState):
return self.outputs[-1]
return None
def get_current_output(self) -> Optional[Any]:
def get_current_output(self) -> Any | None:
"""
Get current output.
"""

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -18,10 +18,10 @@ class LoopEndNode(BaseNode):
_node_data: LoopEndNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopEndNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -30,7 +30,7 @@ class LoopEndNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -3,7 +3,7 @@ import logging
import time
from collections.abc import Generator, Mapping, Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from typing import TYPE_CHECKING, Any, Literal, cast
from configs import dify_config
from core.variables import (
@ -54,10 +54,10 @@ class LoopNode(BaseNode):
_node_data: LoopNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -66,7 +66,7 @@ class LoopNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -289,6 +289,8 @@ class LoopNode(BaseNode):
Returns:
dict: {'check_break_result': bool}
"""
condition_selectors = self._extract_selectors_from_conditions(break_conditions)
extended_selectors = {**loop_variable_selectors, **condition_selectors}
# Run workflow
rst = graph_engine.run()
current_index_variable = variable_pool.get([self.node_id, "index"])
@ -299,8 +301,8 @@ class LoopNode(BaseNode):
check_break_result = False
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id:
event.in_loop_id = self.node_id
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id: # ty: ignore [unresolved-attribute]
event.in_loop_id = self.node_id # ty: ignore [unresolved-attribute]
if (
isinstance(event, BaseNodeEvent)
@ -314,31 +316,30 @@ class LoopNode(BaseNode):
and event.node_type == NodeType.LOOP_END
and not isinstance(event, NodeRunStreamChunkEvent)
):
# Check if variables in break conditions exist and process conditions
# Allow loop internal variables to be used in break conditions
available_conditions = []
for condition in break_conditions:
variable = self.graph_runtime_state.variable_pool.get(condition.variable_selector)
if variable:
available_conditions.append(condition)
# Process conditions if at least one variable is available
if available_conditions:
input_conditions, group_result, check_break_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=available_conditions,
operator=logical_operator,
)
if check_break_result:
break
else:
check_break_result = True
check_break_result = True
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
break
if isinstance(event, NodeRunSucceededEvent):
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
# Check if all variables in break conditions exist
exists_variable = False
for condition in break_conditions:
if not self.graph_runtime_state.variable_pool.get(condition.variable_selector):
exists_variable = False
break
else:
exists_variable = True
if exists_variable:
input_conditions, group_result, check_break_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=break_conditions,
operator=logical_operator,
)
if check_break_result:
break
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# Loop run failed
@ -400,21 +401,21 @@ class LoopNode(BaseNode):
else:
yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index)
# Remove all nodes outputs from variable pool
for node_id in loop_graph.node_ids:
variable_pool.remove([node_id])
_outputs = {}
for loop_variable_key, loop_variable_selector in loop_variable_selectors.items():
_outputs: dict[str, Segment | int | None] = {}
for loop_variable_key, loop_variable_selector in extended_selectors.items():
_loop_variable_segment = variable_pool.get(loop_variable_selector)
if _loop_variable_segment:
_outputs[loop_variable_key] = _loop_variable_segment.value
_outputs[loop_variable_key] = _loop_variable_segment
else:
_outputs[loop_variable_key] = None
_outputs["loop_round"] = current_index + 1
self._node_data.outputs = _outputs
# Remove all nodes outputs from variable pool
for node_id in loop_graph.node_ids:
variable_pool.remove([node_id])
if check_break_result:
return {"check_break_result": True}
@ -433,6 +434,13 @@ class LoopNode(BaseNode):
return {"check_break_result": False}
def _extract_selectors_from_conditions(self, conditions: list) -> dict[str, list[str]]:
return {
condition.variable_selector[1]: condition.variable_selector
for condition in conditions
if condition.variable_selector and len(condition.variable_selector) >= 2
}
def _handle_event_metadata(
self,
*,
@ -522,21 +530,33 @@ class LoopNode(BaseNode):
return variable_mapping
@staticmethod
def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment:
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
"""Get the appropriate segment type for a constant value."""
if var_type in ["array[string]", "array[number]", "array[object]"]:
if value and isinstance(value, str):
value = json.loads(value)
# TODO: Refactor for maintainability:
# 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py)
# 2. Consider moving this method to LoopVariableData class for better encapsulation
if not var_type.is_array_type() or var_type == SegmentType.ARRAY_BOOLEAN:
value = original_value
elif var_type in [
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_STRING,
]:
if original_value and isinstance(original_value, str):
value = json.loads(original_value)
else:
logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type)
value = []
else:
raise AssertionError("this statement should be unreachable.")
try:
return build_segment_with_type(var_type, value)
return build_segment_with_type(var_type, value=value)
except TypeMismatchError as type_exc:
# Attempt to parse the value as a JSON-encoded string, if applicable.
if not isinstance(value, str):
if not isinstance(original_value, str):
raise
try:
value = json.loads(value)
value = json.loads(original_value)
except ValueError:
raise type_exc
return build_segment_with_type(var_type, value)

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -18,10 +18,10 @@ class LoopStartNode(BaseNode):
_node_data: LoopStartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopStartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -30,7 +30,7 @@ class LoopStartNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,14 +1,46 @@
from typing import Any, Literal, Optional
from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field, field_validator
from pydantic import (
BaseModel,
BeforeValidator,
Field,
field_validator,
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm import ModelConfig, VisionConfig
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
_OLD_BOOL_TYPE_NAME = "bool"
_OLD_SELECT_TYPE_NAME = "select"
_VALID_PARAMETER_TYPES = frozenset(
[
SegmentType.STRING, # "string",
SegmentType.NUMBER, # "number",
SegmentType.BOOLEAN,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_BOOLEAN,
_OLD_BOOL_TYPE_NAME, # old boolean type used by Parameter Extractor node
_OLD_SELECT_TYPE_NAME, # string type with enumeration choices.
]
)
class _ParameterConfigError(Exception):
pass
def _validate_type(parameter_type: str) -> SegmentType:
if not isinstance(parameter_type, str):
raise TypeError(f"type should be str, got {type(parameter_type)}, value={parameter_type}")
if parameter_type not in _VALID_PARAMETER_TYPES:
raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.")
if parameter_type == _OLD_BOOL_TYPE_NAME:
return SegmentType.BOOLEAN
elif parameter_type == _OLD_SELECT_TYPE_NAME:
return SegmentType.STRING
return SegmentType(parameter_type)
class ParameterConfig(BaseModel):
@ -17,8 +49,8 @@ class ParameterConfig(BaseModel):
"""
name: str
type: Literal["string", "number", "bool", "select", "array[string]", "array[number]", "array[object]"]
options: Optional[list[str]] = None
type: Annotated[SegmentType, BeforeValidator(_validate_type)]
options: list[str] | None = None
description: str
required: bool
@ -32,17 +64,20 @@ class ParameterConfig(BaseModel):
return str(value)
def is_array_type(self) -> bool:
return self.type in ("array[string]", "array[number]", "array[object]")
return self.type.is_array_type()
def element_type(self) -> Literal["string", "number", "object"]:
if self.type == "array[number]":
return "number"
elif self.type == "array[string]":
return "string"
elif self.type == "array[object]":
return "object"
else:
raise _ParameterConfigError(f"{self.type} is not array type.")
def element_type(self) -> SegmentType:
"""Return the element type of the parameter.
Raises a ValueError if the parameter's type is not an array type.
"""
element_type = self.type.element_type()
# At this point, self.type is guaranteed to be one of `ARRAY_STRING`,
# `ARRAY_NUMBER`, `ARRAY_OBJECT`, or `ARRAY_BOOLEAN`.
#
# See: _VALID_PARAMETER_TYPES for reference.
assert element_type is not None, f"the element type should not be None, {self.type=}"
return element_type
class ParameterExtractorNodeData(BaseNodeData):
@ -53,8 +88,8 @@ class ParameterExtractorNodeData(BaseNodeData):
model: ModelConfig
query: list[str]
parameters: list[ParameterConfig]
instruction: Optional[str] = None
memory: Optional[MemoryConfig] = None
instruction: str | None = None
memory: MemoryConfig | None = None
reasoning_mode: Literal["function_call", "prompt"]
vision: VisionConfig = Field(default_factory=VisionConfig)
@ -63,7 +98,7 @@ class ParameterExtractorNodeData(BaseNodeData):
def set_reasoning_mode(cls, v) -> str:
return v or "function_call"
def get_parameter_json_schema(self) -> dict:
def get_parameter_json_schema(self):
"""
Get parameter json schema.
@ -74,16 +109,18 @@ class ParameterExtractorNodeData(BaseNodeData):
for parameter in self.parameters:
parameter_schema: dict[str, Any] = {"description": parameter.description}
if parameter.type in {"string", "select"}:
if parameter.type == SegmentType.STRING:
parameter_schema["type"] = "string"
elif parameter.type.startswith("array"):
elif parameter.type.is_array_type():
parameter_schema["type"] = "array"
nested_type = parameter.type[6:-1]
parameter_schema["items"] = {"type": nested_type}
element_type = parameter.type.element_type()
if element_type is None:
raise AssertionError("element type should not be None.")
parameter_schema["items"] = {"type": element_type.value}
else:
parameter_schema["type"] = parameter.type
if parameter.type == "select":
if parameter.options:
parameter_schema["enum"] = parameter.options
parameters["properties"][parameter.name] = parameter_schema

View File

@ -1,3 +1,8 @@
from typing import Any
from core.variables.types import SegmentType
class ParameterExtractorNodeError(ValueError):
"""Base error for ParameterExtractorNode."""
@ -48,3 +53,23 @@ class InvalidArrayValueError(ParameterExtractorNodeError):
class InvalidModelModeError(ParameterExtractorNodeError):
"""Raised when the model mode is invalid."""
class InvalidValueTypeError(ParameterExtractorNodeError):
def __init__(
self,
/,
parameter_name: str,
expected_type: SegmentType,
actual_type: SegmentType | None,
value: Any,
):
message = (
f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, "
f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}"
)
super().__init__(message)
self.parameter_name = parameter_name
self.expected_type = expected_type
self.actual_type = actual_type
self.value = value

View File

@ -3,7 +3,7 @@ import json
import logging
import uuid
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import File
@ -26,7 +26,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import SegmentType
from core.variables.types import ArrayValidation, SegmentType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@ -39,22 +39,20 @@ from factories.variable_factory import build_segment_with_type
from .entities import ParameterExtractorNodeData
from .exc import (
InvalidArrayValueError,
InvalidBoolValueError,
InvalidInvokeResultError,
InvalidModelModeError,
InvalidModelTypeError,
InvalidNumberOfParametersError,
InvalidNumberValueError,
InvalidSelectValueError,
InvalidStringValueError,
InvalidTextContentTypeError,
InvalidValueTypeError,
ModelSchemaNotFoundError,
ParameterExtractorNodeError,
RequiredParameterMissingError,
)
from .prompts import (
CHAT_EXAMPLE,
CHAT_GENERATE_JSON_PROMPT,
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
COMPLETION_GENERATE_JSON_PROMPT,
FUNCTION_CALLING_EXTRACTOR_EXAMPLE,
@ -97,10 +95,10 @@ class ParameterExtractorNode(BaseNode):
_node_data: ParameterExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ParameterExtractorNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -109,7 +107,7 @@ class ParameterExtractorNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -118,11 +116,11 @@ class ParameterExtractorNode(BaseNode):
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
_model_instance: Optional[ModelInstance] = None
_model_config: Optional[ModelConfigWithCredentialsEntity] = None
_model_instance: ModelInstance | None = None
_model_config: ModelConfigWithCredentialsEntity | None = None
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: dict | None = None):
return {
"model": {
"prompt_templates": {
@ -142,7 +140,7 @@ class ParameterExtractorNode(BaseNode):
"""
Run the node.
"""
node_data = cast(ParameterExtractorNodeData, self._node_data)
node_data = self._node_data
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else ""
@ -297,7 +295,7 @@ class ParameterExtractorNode(BaseNode):
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
stop: list[str],
) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]:
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=node_data_model.completion_params,
@ -332,9 +330,9 @@ class ParameterExtractorNode(BaseNode):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
"""
Generate function call prompt.
@ -414,9 +412,9 @@ class ParameterExtractorNode(BaseNode):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
"""
Generate prompt engineering prompt.
@ -452,9 +450,9 @@ class ParameterExtractorNode(BaseNode):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
"""
Generate completion prompt.
@ -486,9 +484,9 @@ class ParameterExtractorNode(BaseNode):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
"""
Generate chat prompt.
@ -548,10 +546,7 @@ class ParameterExtractorNode(BaseNode):
return prompt_messages
def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
"""
Validate result.
"""
def _validate_result(self, data: ParameterExtractorNodeData, result: dict):
if len(data.parameters) != len(result):
raise InvalidNumberOfParametersError("Invalid number of parameters")
@ -559,105 +554,110 @@ class ParameterExtractorNode(BaseNode):
if parameter.required and parameter.name not in result:
raise RequiredParameterMissingError(f"Parameter {parameter.name} is required")
if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options:
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float):
raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}")
if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool):
raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}")
if parameter.type == "string" and not isinstance(result.get(parameter.name), str):
raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}")
if parameter.type.startswith("array"):
parameters = result.get(parameter.name)
if not isinstance(parameters, list):
raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}")
nested_type = parameter.type[6:-1]
for item in parameters:
if nested_type == "number" and not isinstance(item, int | float):
raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
if nested_type == "string" and not isinstance(item, str):
raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
if nested_type == "object" and not isinstance(item, dict):
raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
param_value = result.get(parameter.name)
if not parameter.type.is_valid(param_value, array_validation=ArrayValidation.ALL):
inferred_type = SegmentType.infer_segment_type(param_value)
raise InvalidValueTypeError(
parameter_name=parameter.name,
expected_type=parameter.type,
actual_type=inferred_type,
value=param_value,
)
if parameter.type == SegmentType.STRING and parameter.options:
if param_value not in parameter.options:
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
return result
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
@staticmethod
def _transform_number(value: int | float | str | bool) -> int | float | None:
"""
Attempts to transform the input into an integer or float.
Returns:
int or float: The transformed number if the conversion is successful.
None: If the transformation fails.
Note:
Boolean values `True` and `False` are converted to integers `1` and `0`, respectively.
This behavior ensures compatibility with existing workflows that may use boolean types as integers.
"""
if isinstance(value, bool):
return int(value)
elif isinstance(value, (int, float)):
return value
elif not isinstance(value, str):
return None
if "." in value:
try:
return float(value)
except ValueError:
return None
else:
try:
return int(value)
except ValueError:
return None
def _transform_result(self, data: ParameterExtractorNodeData, result: dict):
"""
Transform result into standard format.
"""
transformed_result = {}
transformed_result: dict[str, Any] = {}
for parameter in data.parameters:
if parameter.name in result:
param_value = result[parameter.name]
# transform value
if parameter.type == "number":
if isinstance(result[parameter.name], int | float):
transformed_result[parameter.name] = result[parameter.name]
elif isinstance(result[parameter.name], str):
try:
if "." in result[parameter.name]:
result[parameter.name] = float(result[parameter.name])
else:
result[parameter.name] = int(result[parameter.name])
except ValueError:
pass
else:
pass
# TODO: bool is not supported in the current version
# elif parameter.type == 'bool':
# if isinstance(result[parameter.name], bool):
# transformed_result[parameter.name] = bool(result[parameter.name])
# elif isinstance(result[parameter.name], str):
# if result[parameter.name].lower() in ['true', 'false']:
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true')
# elif isinstance(result[parameter.name], int):
# transformed_result[parameter.name] = bool(result[parameter.name])
elif parameter.type in {"string", "select"}:
if isinstance(result[parameter.name], str):
transformed_result[parameter.name] = result[parameter.name]
if parameter.type == SegmentType.NUMBER:
transformed = self._transform_number(param_value)
if transformed is not None:
transformed_result[parameter.name] = transformed
elif parameter.type == SegmentType.BOOLEAN:
if isinstance(result[parameter.name], (bool, int)):
transformed_result[parameter.name] = bool(result[parameter.name])
# elif isinstance(result[parameter.name], str):
# if result[parameter.name].lower() in ["true", "false"]:
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == "true")
elif parameter.type == SegmentType.STRING:
if isinstance(param_value, str):
transformed_result[parameter.name] = param_value
elif parameter.is_array_type():
if isinstance(result[parameter.name], list):
if isinstance(param_value, list):
nested_type = parameter.element_type()
assert nested_type is not None
segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[])
transformed_result[parameter.name] = segment_value
for item in result[parameter.name]:
if nested_type == "number":
if isinstance(item, int | float):
segment_value.value.append(item)
elif isinstance(item, str):
try:
if "." in item:
segment_value.value.append(float(item))
else:
segment_value.value.append(int(item))
except ValueError:
pass
elif nested_type == "string":
for item in param_value:
if nested_type == SegmentType.NUMBER:
transformed = self._transform_number(item)
if transformed is not None:
segment_value.value.append(transformed)
elif nested_type == SegmentType.STRING:
if isinstance(item, str):
segment_value.value.append(item)
elif nested_type == "object":
elif nested_type == SegmentType.OBJECT:
if isinstance(item, dict):
segment_value.value.append(item)
elif nested_type == SegmentType.BOOLEAN:
if isinstance(item, bool):
segment_value.value.append(item)
if parameter.name not in transformed_result:
if parameter.type == "number":
transformed_result[parameter.name] = 0
elif parameter.type == "bool":
transformed_result[parameter.name] = False
elif parameter.type in {"string", "select"}:
transformed_result[parameter.name] = ""
elif parameter.type.startswith("array"):
if parameter.type.is_array_type():
transformed_result[parameter.name] = build_segment_with_type(
segment_type=SegmentType(parameter.type), value=[]
)
elif parameter.type in (SegmentType.STRING, SegmentType.SECRET):
transformed_result[parameter.name] = ""
elif parameter.type == SegmentType.NUMBER:
transformed_result[parameter.name] = 0
elif parameter.type == SegmentType.BOOLEAN:
transformed_result[parameter.name] = False
else:
raise AssertionError("this statement should be unreachable.")
return transformed_result
def _extract_complete_json_response(self, result: str) -> Optional[dict]:
def _extract_complete_json_response(self, result: str) -> dict | None:
"""
Extract complete json response.
"""
@ -672,7 +672,7 @@ class ParameterExtractorNode(BaseNode):
logger.info("extra error: %s", result)
return None
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None:
"""
Extract json from tool call.
"""
@ -691,7 +691,7 @@ class ParameterExtractorNode(BaseNode):
logger.info("extra error: %s", result)
return None
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
def _generate_default_result(self, data: ParameterExtractorNodeData):
"""
Generate default result.
"""
@ -711,7 +711,7 @@ class ParameterExtractorNode(BaseNode):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
model_mode = ModelMode(node_data.model.mode)
@ -738,7 +738,7 @@ class ParameterExtractorNode(BaseNode):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
@ -753,7 +753,7 @@ class ParameterExtractorNode(BaseNode):
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
role=PromptMessageRole.SYSTEM,
text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction),
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str).replace("{{instructions}}", instruction),
)
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
return [system_prompt_messages, user_prompt_message]
@ -774,7 +774,7 @@ class ParameterExtractorNode(BaseNode):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
context: Optional[str],
context: str | None,
) -> int:
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)

View File

@ -1,5 +1,3 @@
from typing import Optional
from pydantic import BaseModel, Field
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
@ -16,8 +14,8 @@ class QuestionClassifierNodeData(BaseNodeData):
query_variable_selector: list[str]
model: ModelConfig
classes: list[ClassConfig]
instruction: Optional[str] = None
memory: Optional[MemoryConfig] = None
instruction: str | None = None
memory: MemoryConfig | None = None
vision: VisionConfig = Field(default_factory=VisionConfig)
@property

View File

@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@ -59,11 +59,11 @@ class QuestionClassifierNode(BaseNode):
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
previous_node_id: str | None = None,
thread_pool_id: str | None = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
):
super().__init__(
id=id,
config=config,
@ -83,10 +83,10 @@ class QuestionClassifierNode(BaseNode):
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = QuestionClassifierNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -95,7 +95,7 @@ class QuestionClassifierNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -109,7 +109,7 @@ class QuestionClassifierNode(BaseNode):
return "1"
def _run(self):
node_data = cast(QuestionClassifierNodeData, self._node_data)
node_data = self._node_data
variable_pool = self.graph_runtime_state.variable_pool
# extract variables
@ -275,7 +275,7 @@ class QuestionClassifierNode(BaseNode):
return variable_mapping
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: dict | None = None):
"""
Get default config of node.
:param filters: filter by node config parameters.
@ -288,7 +288,7 @@ class QuestionClassifierNode(BaseNode):
node_data: QuestionClassifierNodeData,
query: str,
model_config: ModelConfigWithCredentialsEntity,
context: Optional[str],
context: str | None,
) -> int:
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
@ -331,7 +331,7 @@ class QuestionClassifierNode(BaseNode):
self,
node_data: QuestionClassifierNodeData,
query: str,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
@ -15,10 +15,10 @@ class StartNode(BaseNode):
_node_data: StartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = StartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -27,7 +27,7 @@ class StartNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,6 +1,6 @@
import os
from collections.abc import Mapping, Sequence
from typing import Any, Optional
from typing import Any
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.entities.node_entities import NodeRunResult
@ -18,10 +18,10 @@ class TemplateTransformNode(BaseNode):
_node_data: TemplateTransformNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = TemplateTransformNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -30,7 +30,7 @@ class TemplateTransformNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -40,7 +40,7 @@ class TemplateTransformNode(BaseNode):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: dict | None = None):
"""
Get default config of node.
:param filters: filter by node config parameters.

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -45,7 +45,7 @@ class ToolNode(BaseNode):
_node_data: ToolNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ToolNodeData.model_validate(data)
@classmethod
@ -57,7 +57,7 @@ class ToolNode(BaseNode):
Run the tool node
"""
node_data = cast(ToolNodeData, self._node_data)
node_data = self._node_data
# fetch tool icon
tool_info = {
@ -439,7 +439,7 @@ class ToolNode(BaseNode):
return result
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -448,7 +448,7 @@ class ToolNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,5 +1,3 @@
from typing import Optional
from pydantic import BaseModel
from core.variables.types import SegmentType
@ -33,4 +31,4 @@ class VariableAssignerNodeData(BaseNodeData):
type: str = "variable-assigner"
output_type: str
variables: list[list[str]]
advanced_settings: Optional[AdvancedSettings] = None
advanced_settings: AdvancedSettings | None = None

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.variables.segments import Segment
from core.workflow.entities.node_entities import NodeRunResult
@ -15,10 +15,10 @@ class VariableAggregatorNode(BaseNode):
_node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -27,7 +27,7 @@ class VariableAggregatorNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -16,7 +16,7 @@ class UpdatedVariable(BaseModel):
name: str
selector: Sequence[str]
value_type: SegmentType
new_value: Any
new_value: Any = None
_T = TypeVar("_T", bound=MutableMapping[str, Any])
@ -25,7 +25,7 @@ _T = TypeVar("_T", bound=MutableMapping[str, Any])
def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable:
if len(selector) < SELECTORS_LENGTH:
raise Exception("selector too short")
node_id, var_name = selector[:2]
_, var_name = selector[:2]
return UpdatedVariable(
name=var_name,
selector=list(selector[:2]),

View File

@ -11,7 +11,7 @@ from .exc import VariableOperatorNodeError
class ConversationVariableUpdaterImpl:
_engine: Engine | None
def __init__(self, engine: Engine | None = None) -> None:
def __init__(self, engine: Engine | None = None):
self._engine = engine
def _get_engine(self) -> Engine:

View File

@ -1,7 +1,8 @@
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, TypeAlias
from typing import TYPE_CHECKING, Any, TypeAlias
from core.variables import SegmentType, Variable
from core.variables.segments import BooleanSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.node_entities import NodeRunResult
@ -29,10 +30,10 @@ class VariableAssignerNode(BaseNode):
_node_data: VariableAssignerData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -41,7 +42,7 @@ class VariableAssignerNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -57,10 +58,10 @@ class VariableAssignerNode(BaseNode):
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
previous_node_id: str | None = None,
thread_pool_id: str | None = None,
conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
) -> None:
):
super().__init__(
id=id,
config=config,
@ -158,8 +159,8 @@ class VariableAssignerNode(BaseNode):
def get_zero_value(t: SegmentType):
# TODO(QuantumGhost): this should be a method of `SegmentType`.
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
return variable_factory.build_segment([])
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER | SegmentType.ARRAY_BOOLEAN:
return variable_factory.build_segment_with_type(t, [])
case SegmentType.OBJECT:
return variable_factory.build_segment({})
case SegmentType.STRING:
@ -170,5 +171,7 @@ def get_zero_value(t: SegmentType):
return variable_factory.build_segment(0.0)
case SegmentType.NUMBER:
return variable_factory.build_segment(0)
case SegmentType.BOOLEAN:
return BooleanSegment(value=False)
case _:
raise VariableOperatorNodeError(f"unsupported variable type: {t}")

View File

@ -4,9 +4,11 @@ from core.variables import SegmentType
EMPTY_VALUE_MAPPING = {
SegmentType.STRING: "",
SegmentType.NUMBER: 0,
SegmentType.BOOLEAN: False,
SegmentType.OBJECT: {},
SegmentType.ARRAY_ANY: [],
SegmentType.ARRAY_STRING: [],
SegmentType.ARRAY_NUMBER: [],
SegmentType.ARRAY_OBJECT: [],
SegmentType.ARRAY_BOOLEAN: [],
}

View File

@ -32,5 +32,5 @@ class ConversationIDNotFoundError(VariableOperatorNodeError):
class InvalidDataError(VariableOperatorNodeError):
def __init__(self, message: str) -> None:
def __init__(self, message: str):
super().__init__(message)

View File

@ -16,28 +16,15 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
SegmentType.NUMBER,
SegmentType.INTEGER,
SegmentType.FLOAT,
SegmentType.BOOLEAN,
}
case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
# Only number variable can be added, subtracted, multiplied or divided
return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}
case Operation.APPEND | Operation.EXTEND:
case Operation.APPEND | Operation.EXTEND | Operation.REMOVE_FIRST | Operation.REMOVE_LAST:
# Only array variable can be appended or extended
return variable_type in {
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_FILE,
}
case Operation.REMOVE_FIRST | Operation.REMOVE_LAST:
# Only array variable can have elements removed
return variable_type in {
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_FILE,
}
return variable_type.is_array_type()
case _:
return False
@ -50,7 +37,7 @@ def is_variable_input_supported(*, operation: Operation):
def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation):
match variable_type:
case SegmentType.STRING | SegmentType.OBJECT:
case SegmentType.STRING | SegmentType.OBJECT | SegmentType.BOOLEAN:
return operation in {Operation.OVER_WRITE, Operation.SET}
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
return operation in {
@ -72,6 +59,9 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
case SegmentType.STRING:
return isinstance(value, str)
case SegmentType.BOOLEAN:
return isinstance(value, bool)
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
if not isinstance(value, int | float):
return False
@ -91,6 +81,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
return isinstance(value, int | float)
case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND:
return isinstance(value, dict)
case SegmentType.ARRAY_BOOLEAN if operation == Operation.APPEND:
return isinstance(value, bool)
# Array & Extend / Overwrite
case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
@ -101,6 +93,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
return isinstance(value, list) and all(isinstance(item, int | float) for item in value)
case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
return isinstance(value, list) and all(isinstance(item, dict) for item in value)
case SegmentType.ARRAY_BOOLEAN if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
return isinstance(value, list) and all(isinstance(item, bool) for item in value)
case _:
return False

View File

@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any, Optional, cast
from typing import Any, cast
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
@ -58,10 +58,10 @@ class VariableAssignerNode(BaseNode):
_node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -70,7 +70,7 @@ class VariableAssignerNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]: