mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
Merge branch 'main' into feat/memory-orchestration-be
# Conflicts: # api/core/app/apps/advanced_chat/app_runner.py # api/core/prompt/entities/advanced_prompt_entities.py # api/core/variables/segments.py
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from packaging.version import Version
|
||||
from pydantic import ValidationError
|
||||
@ -66,10 +66,10 @@ class AgentNode(BaseNode):
|
||||
_node_type = NodeType.AGENT
|
||||
_node_data: AgentNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = AgentNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -78,7 +78,7 @@ class AgentNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -153,7 +153,7 @@ class AgentNode(BaseNode):
|
||||
messages=message_stream,
|
||||
tool_info={
|
||||
"icon": self.agent_strategy_icon,
|
||||
"agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
|
||||
"agent_strategy": self._node_data.agent_strategy_name,
|
||||
},
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=self.user_id,
|
||||
@ -394,15 +394,14 @@ class AgentNode(BaseNode):
|
||||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}"
|
||||
== cast(AgentNodeData, self._node_data).agent_strategy_provider_name
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
icon = None
|
||||
return icon
|
||||
|
||||
def _fetch_memory(self, model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
|
||||
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
|
||||
# get conversation id
|
||||
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID.value]
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from enum import Enum, StrEnum
|
||||
from enum import IntEnum, StrEnum, auto
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -25,9 +25,9 @@ class AgentNodeData(BaseNodeData):
|
||||
agent_parameters: dict[str, AgentInput]
|
||||
|
||||
|
||||
class ParamsAutoGenerated(Enum):
|
||||
CLOSE = 0
|
||||
OPEN = 1
|
||||
class ParamsAutoGenerated(IntEnum):
|
||||
CLOSE = auto()
|
||||
OPEN = auto()
|
||||
|
||||
|
||||
class AgentOldVersionModelFeatures(StrEnum):
|
||||
@ -38,8 +38,8 @@ class AgentOldVersionModelFeatures(StrEnum):
|
||||
TOOL_CALL = "tool-call"
|
||||
MULTI_TOOL_CALL = "multi-tool-call"
|
||||
AGENT_THOUGHT = "agent-thought"
|
||||
VISION = "vision"
|
||||
VISION = auto()
|
||||
STREAM_TOOL_CALL = "stream-tool-call"
|
||||
DOCUMENT = "document"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
DOCUMENT = auto()
|
||||
VIDEO = auto()
|
||||
AUDIO = auto()
|
||||
|
||||
@ -1,6 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class AgentNodeError(Exception):
|
||||
"""Base exception for all agent node errors."""
|
||||
|
||||
@ -12,7 +9,7 @@ class AgentNodeError(Exception):
|
||||
class AgentStrategyError(AgentNodeError):
|
||||
"""Exception raised when there's an error with the agent strategy."""
|
||||
|
||||
def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None):
|
||||
def __init__(self, message: str, strategy_name: str | None = None, provider_name: str | None = None):
|
||||
self.strategy_name = strategy_name
|
||||
self.provider_name = provider_name
|
||||
super().__init__(message)
|
||||
@ -21,7 +18,7 @@ class AgentStrategyError(AgentNodeError):
|
||||
class AgentStrategyNotFoundError(AgentStrategyError):
|
||||
"""Exception raised when the specified agent strategy is not found."""
|
||||
|
||||
def __init__(self, strategy_name: str, provider_name: Optional[str] = None):
|
||||
def __init__(self, strategy_name: str, provider_name: str | None = None):
|
||||
super().__init__(
|
||||
f"Agent strategy '{strategy_name}' not found"
|
||||
+ (f" for provider '{provider_name}'" if provider_name else ""),
|
||||
@ -33,7 +30,7 @@ class AgentStrategyNotFoundError(AgentStrategyError):
|
||||
class AgentInvocationError(AgentNodeError):
|
||||
"""Exception raised when there's an error invoking the agent."""
|
||||
|
||||
def __init__(self, message: str, original_error: Optional[Exception] = None):
|
||||
def __init__(self, message: str, original_error: Exception | None = None):
|
||||
self.original_error = original_error
|
||||
super().__init__(message)
|
||||
|
||||
@ -41,7 +38,7 @@ class AgentInvocationError(AgentNodeError):
|
||||
class AgentParameterError(AgentNodeError):
|
||||
"""Exception raised when there's an error with agent parameters."""
|
||||
|
||||
def __init__(self, message: str, parameter_name: Optional[str] = None):
|
||||
def __init__(self, message: str, parameter_name: str | None = None):
|
||||
self.parameter_name = parameter_name
|
||||
super().__init__(message)
|
||||
|
||||
@ -49,7 +46,7 @@ class AgentParameterError(AgentNodeError):
|
||||
class AgentVariableError(AgentNodeError):
|
||||
"""Exception raised when there's an error with variables in the agent node."""
|
||||
|
||||
def __init__(self, message: str, variable_name: Optional[str] = None):
|
||||
def __init__(self, message: str, variable_name: str | None = None):
|
||||
self.variable_name = variable_name
|
||||
super().__init__(message)
|
||||
|
||||
@ -71,7 +68,7 @@ class AgentInputTypeError(AgentNodeError):
|
||||
class ToolFileError(AgentNodeError):
|
||||
"""Exception raised when there's an error with a tool file."""
|
||||
|
||||
def __init__(self, message: str, file_id: Optional[str] = None):
|
||||
def __init__(self, message: str, file_id: str | None = None):
|
||||
self.file_id = file_id
|
||||
super().__init__(message)
|
||||
|
||||
@ -86,7 +83,7 @@ class ToolFileNotFoundError(ToolFileError):
|
||||
class AgentMessageTransformError(AgentNodeError):
|
||||
"""Exception raised when there's an error transforming agent messages."""
|
||||
|
||||
def __init__(self, message: str, original_error: Optional[Exception] = None):
|
||||
def __init__(self, message: str, original_error: Exception | None = None):
|
||||
self.original_error = original_error
|
||||
super().__init__(message)
|
||||
|
||||
@ -94,7 +91,7 @@ class AgentMessageTransformError(AgentNodeError):
|
||||
class AgentModelError(AgentNodeError):
|
||||
"""Exception raised when there's an error with the model used by the agent."""
|
||||
|
||||
def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None):
|
||||
def __init__(self, message: str, model_name: str | None = None, provider: str | None = None):
|
||||
self.model_name = model_name
|
||||
self.provider = provider
|
||||
super().__init__(message)
|
||||
@ -103,7 +100,7 @@ class AgentModelError(AgentNodeError):
|
||||
class AgentMemoryError(AgentNodeError):
|
||||
"""Exception raised when there's an error with the agent's memory."""
|
||||
|
||||
def __init__(self, message: str, conversation_id: Optional[str] = None):
|
||||
def __init__(self, message: str, conversation_id: str | None = None):
|
||||
self.conversation_id = conversation_id
|
||||
super().__init__(message)
|
||||
|
||||
@ -114,9 +111,9 @@ class AgentVariableTypeError(AgentNodeError):
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
variable_name: Optional[str] = None,
|
||||
expected_type: Optional[str] = None,
|
||||
actual_type: Optional[str] = None,
|
||||
variable_name: str | None = None,
|
||||
expected_type: str | None = None,
|
||||
actual_type: str | None = None,
|
||||
):
|
||||
self.variable_name = variable_name
|
||||
self.expected_type = expected_type
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.variables import ArrayFileSegment, FileSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
@ -22,10 +22,10 @@ class AnswerNode(BaseNode):
|
||||
|
||||
_node_data: AnswerNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = AnswerNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -34,7 +34,7 @@ class AnswerNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
||||
@ -134,7 +134,7 @@ class AnswerStreamGeneratorRouter:
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
answer_dependencies: dict[str, list[str]],
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Recursive fetch answer dependencies
|
||||
:param current_node_id: current node id
|
||||
|
||||
@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnswerStreamProcessor(StreamProcessor):
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool):
|
||||
super().__init__(graph, variable_pool)
|
||||
self.generate_routes = graph.answer_stream_generate_routes
|
||||
self.route_position = {}
|
||||
@ -52,12 +52,12 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
yield event
|
||||
elif isinstance(event, NodeRunSucceededEvent | NodeRunExceptionEvent):
|
||||
yield event
|
||||
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
||||
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: # ty: ignore [unresolved-attribute]
|
||||
# update self.route_position after all stream event finished
|
||||
for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
|
||||
for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]: # ty: ignore [unresolved-attribute]
|
||||
self.route_position[answer_node_id] += 1
|
||||
|
||||
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
|
||||
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] # ty: ignore [unresolved-attribute]
|
||||
|
||||
self._remove_unreachable_nodes(event)
|
||||
|
||||
@ -66,9 +66,9 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
else:
|
||||
yield event
|
||||
|
||||
def reset(self) -> None:
|
||||
def reset(self):
|
||||
self.route_position = {}
|
||||
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
|
||||
for answer_node_id, _ in self.generate_routes.answer_generate_route.items():
|
||||
self.route_position[answer_node_id] = 0
|
||||
self.rest_node_ids = self.graph.node_ids.copy()
|
||||
self.current_stream_chunk_generating_node_ids = {}
|
||||
@ -149,9 +149,6 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
return []
|
||||
|
||||
stream_output_value_selector = event.from_variable_selector
|
||||
if not stream_output_value_selector:
|
||||
return []
|
||||
|
||||
stream_out_answer_node_ids = []
|
||||
for answer_node_id, route_position in self.route_position.items():
|
||||
if answer_node_id not in self.rest_node_ids:
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
|
||||
@ -11,7 +10,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamProcessor(ABC):
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool):
|
||||
self.graph = graph
|
||||
self.variable_pool = variable_pool
|
||||
self.rest_node_ids = graph.node_ids.copy()
|
||||
@ -20,7 +19,7 @@ class StreamProcessor(ABC):
|
||||
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None:
|
||||
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent):
|
||||
finished_node_id = event.route_node_state.node_id
|
||||
if finished_node_id not in self.rest_node_ids:
|
||||
return
|
||||
@ -72,7 +71,7 @@ class StreamProcessor(ABC):
|
||||
for node_id in unreachable_first_node_ids:
|
||||
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
|
||||
|
||||
def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]:
|
||||
def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: str | None = None) -> list[str]:
|
||||
if node_id not in self.rest_node_ids:
|
||||
self.rest_node_ids.append(node_id)
|
||||
node_ids = []
|
||||
@ -89,7 +88,7 @@ class StreamProcessor(ABC):
|
||||
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify))
|
||||
return node_ids
|
||||
|
||||
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
|
||||
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]):
|
||||
"""
|
||||
remove target node ids until merge
|
||||
"""
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -19,9 +19,9 @@ class GenerateRouteChunk(BaseModel):
|
||||
Generate Route Chunk.
|
||||
"""
|
||||
|
||||
class ChunkType(Enum):
|
||||
VAR = "var"
|
||||
TEXT = "text"
|
||||
class ChunkType(StrEnum):
|
||||
VAR = auto()
|
||||
TEXT = auto()
|
||||
|
||||
type: ChunkType = Field(..., description="generate route chunk type")
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
from abc import ABC
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
@ -23,12 +23,12 @@ NumberType = Union[int, float]
|
||||
|
||||
|
||||
class DefaultValue(BaseModel):
|
||||
value: Any
|
||||
value: Any = None
|
||||
type: DefaultValueType
|
||||
key: str
|
||||
|
||||
@staticmethod
|
||||
def _parse_json(value: str) -> Any:
|
||||
def _parse_json(value: str):
|
||||
"""Unified JSON parsing handler"""
|
||||
try:
|
||||
return json.loads(value)
|
||||
@ -121,10 +121,10 @@ class RetryConfig(BaseModel):
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: Optional[str] = None
|
||||
desc: str | None = None
|
||||
version: str = "1"
|
||||
error_strategy: Optional[ErrorStrategy] = None
|
||||
default_value: Optional[list[DefaultValue]] = None
|
||||
error_strategy: ErrorStrategy | None = None
|
||||
default_value: list[DefaultValue] | None = None
|
||||
retry_config: RetryConfig = RetryConfig()
|
||||
|
||||
@property
|
||||
@ -135,7 +135,7 @@ class BaseNodeData(ABC, BaseModel):
|
||||
|
||||
|
||||
class BaseIterationNodeData(BaseNodeData):
|
||||
start_node_id: Optional[str] = None
|
||||
start_node_id: str | None = None
|
||||
|
||||
|
||||
class BaseIterationState(BaseModel):
|
||||
@ -150,7 +150,7 @@ class BaseIterationState(BaseModel):
|
||||
|
||||
|
||||
class BaseLoopNodeData(BaseNodeData):
|
||||
start_node_id: Optional[str] = None
|
||||
start_node_id: str | None = None
|
||||
|
||||
|
||||
class BaseLoopState(BaseModel):
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Union
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
@ -26,9 +26,9 @@ class BaseNode:
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph: "Graph",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
previous_node_id: Optional[str] = None,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
) -> None:
|
||||
previous_node_id: str | None = None,
|
||||
thread_pool_id: str | None = None,
|
||||
):
|
||||
self.id = id
|
||||
self.tenant_id = graph_init_params.tenant_id
|
||||
self.app_id = graph_init_params.app_id
|
||||
@ -51,7 +51,7 @@ class BaseNode:
|
||||
self.node_id = node_id
|
||||
|
||||
@abstractmethod
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
|
||||
def init_node_data(self, data: Mapping[str, Any]): ...
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
|
||||
@ -141,7 +141,7 @@ class BaseNode:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
def get_default_config(cls, filters: dict | None = None):
|
||||
return {}
|
||||
|
||||
@property
|
||||
@ -170,7 +170,7 @@ class BaseNode:
|
||||
# to BaseNodeData properties in a type-safe way
|
||||
|
||||
@abstractmethod
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
"""Get the error strategy for this node."""
|
||||
...
|
||||
|
||||
@ -185,7 +185,7 @@ class BaseNode:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
"""Get the node description."""
|
||||
...
|
||||
|
||||
@ -201,7 +201,7 @@ class BaseNode:
|
||||
|
||||
# Public interface properties that delegate to abstract methods
|
||||
@property
|
||||
def error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def error_strategy(self) -> ErrorStrategy | None:
|
||||
"""Get the error strategy for this node."""
|
||||
return self._get_error_strategy()
|
||||
|
||||
@ -216,7 +216,7 @@ class BaseNode:
|
||||
return self._get_title()
|
||||
|
||||
@property
|
||||
def description(self) -> Optional[str]:
|
||||
def description(self) -> str | None:
|
||||
"""Get the node description."""
|
||||
return self._get_description()
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from decimal import Decimal
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
@ -8,6 +8,7 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
@ -27,10 +28,10 @@ class CodeNode(BaseNode):
|
||||
|
||||
_node_data: CodeNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = CodeNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -39,7 +40,7 @@ class CodeNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -49,7 +50,7 @@ class CodeNode(BaseNode):
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
def get_default_config(cls, filters: dict | None = None):
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters.
|
||||
@ -119,6 +120,14 @@ class CodeNode(BaseNode):
|
||||
|
||||
return value.replace("\x00", "")
|
||||
|
||||
def _check_boolean(self, value: bool | None, variable: str) -> bool | None:
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, bool):
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a boolean")
|
||||
|
||||
return value
|
||||
|
||||
def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
|
||||
"""
|
||||
Check number
|
||||
@ -152,7 +161,7 @@ class CodeNode(BaseNode):
|
||||
def _transform_result(
|
||||
self,
|
||||
result: Mapping[str, Any],
|
||||
output_schema: Optional[dict[str, CodeNodeData.Output]],
|
||||
output_schema: dict[str, CodeNodeData.Output] | None,
|
||||
prefix: str = "",
|
||||
depth: int = 1,
|
||||
):
|
||||
@ -173,6 +182,8 @@ class CodeNode(BaseNode):
|
||||
prefix=f"{prefix}.{output_name}" if prefix else output_name,
|
||||
depth=depth + 1,
|
||||
)
|
||||
elif isinstance(output_value, bool):
|
||||
self._check_boolean(output_value, variable=f"{prefix}.{output_name}" if prefix else output_name)
|
||||
elif isinstance(output_value, int | float):
|
||||
self._check_number(
|
||||
value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name
|
||||
@ -232,7 +243,7 @@ class CodeNode(BaseNode):
|
||||
if output_name not in result:
|
||||
raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.")
|
||||
|
||||
if output_config.type == "object":
|
||||
if output_config.type == SegmentType.OBJECT:
|
||||
# check if output is object
|
||||
if not isinstance(result.get(output_name), dict):
|
||||
if result[output_name] is None:
|
||||
@ -249,18 +260,28 @@ class CodeNode(BaseNode):
|
||||
prefix=f"{prefix}.{output_name}",
|
||||
depth=depth + 1,
|
||||
)
|
||||
elif output_config.type == "number":
|
||||
elif output_config.type == SegmentType.NUMBER:
|
||||
# check if number available
|
||||
transformed_result[output_name] = self._check_number(
|
||||
value=result[output_name], variable=f"{prefix}{dot}{output_name}"
|
||||
)
|
||||
elif output_config.type == "string":
|
||||
checked = self._check_number(value=result[output_name], variable=f"{prefix}{dot}{output_name}")
|
||||
# If the output is a boolean and the output schema specifies a NUMBER type,
|
||||
# convert the boolean value to an integer.
|
||||
#
|
||||
# This ensures compatibility with existing workflows that may use
|
||||
# `True` and `False` as values for NUMBER type outputs.
|
||||
transformed_result[output_name] = self._convert_boolean_to_int(checked)
|
||||
|
||||
elif output_config.type == SegmentType.STRING:
|
||||
# check if string available
|
||||
transformed_result[output_name] = self._check_string(
|
||||
value=result[output_name],
|
||||
variable=f"{prefix}{dot}{output_name}",
|
||||
)
|
||||
elif output_config.type == "array[number]":
|
||||
elif output_config.type == SegmentType.BOOLEAN:
|
||||
transformed_result[output_name] = self._check_boolean(
|
||||
value=result[output_name],
|
||||
variable=f"{prefix}{dot}{output_name}",
|
||||
)
|
||||
elif output_config.type == SegmentType.ARRAY_NUMBER:
|
||||
# check if array of number available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
@ -278,10 +299,17 @@ class CodeNode(BaseNode):
|
||||
)
|
||||
|
||||
transformed_result[output_name] = [
|
||||
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
# If the element is a boolean and the output schema specifies a `array[number]` type,
|
||||
# convert the boolean value to an integer.
|
||||
#
|
||||
# This ensures compatibility with existing workflows that may use
|
||||
# `True` and `False` as values for NUMBER type outputs.
|
||||
self._convert_boolean_to_int(
|
||||
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]"),
|
||||
)
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
elif output_config.type == "array[string]":
|
||||
elif output_config.type == SegmentType.ARRAY_STRING:
|
||||
# check if array of string available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
@ -302,7 +330,7 @@ class CodeNode(BaseNode):
|
||||
self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
elif output_config.type == "array[object]":
|
||||
elif output_config.type == SegmentType.ARRAY_OBJECT:
|
||||
# check if array of object available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
@ -340,6 +368,22 @@ class CodeNode(BaseNode):
|
||||
)
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
elif output_config.type == SegmentType.ARRAY_BOOLEAN:
|
||||
# check if array of object available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an array,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
transformed_result[output_name] = [
|
||||
self._check_boolean(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
|
||||
else:
|
||||
raise OutputValidationError(f"Output type {output_config.type} is not supported.")
|
||||
|
||||
@ -374,3 +418,16 @@ class CodeNode(BaseNode):
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return self._node_data.retry_config.retry_enabled
|
||||
|
||||
@staticmethod
|
||||
def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None:
|
||||
"""This function convert boolean to integers when the output schema specifies a NUMBER type.
|
||||
|
||||
This ensures compatibility with existing workflows that may use
|
||||
`True` and `False` as values for NUMBER type outputs.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, bool):
|
||||
return int(value)
|
||||
return value
|
||||
|
||||
@ -1,11 +1,31 @@
|
||||
from typing import Literal, Optional
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import AfterValidator, BaseModel
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
_ALLOWED_OUTPUT_FROM_CODE = frozenset(
|
||||
[
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _validate_type(segment_type: SegmentType) -> SegmentType:
|
||||
if segment_type not in _ALLOWED_OUTPUT_FROM_CODE:
|
||||
raise ValueError(f"invalid type for code output, expected {_ALLOWED_OUTPUT_FROM_CODE}, actual {segment_type}")
|
||||
return segment_type
|
||||
|
||||
|
||||
class CodeNodeData(BaseNodeData):
|
||||
"""
|
||||
@ -13,8 +33,8 @@ class CodeNodeData(BaseNodeData):
|
||||
"""
|
||||
|
||||
class Output(BaseModel):
|
||||
type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
|
||||
children: Optional[dict[str, "CodeNodeData.Output"]] = None
|
||||
type: Annotated[SegmentType, AfterValidator(_validate_type)]
|
||||
children: dict[str, "CodeNodeData.Output"] | None = None
|
||||
|
||||
class Dependency(BaseModel):
|
||||
name: str
|
||||
@ -24,4 +44,4 @@ class CodeNodeData(BaseNodeData):
|
||||
code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]
|
||||
code: str
|
||||
outputs: dict[str, Output]
|
||||
dependencies: Optional[list[Dependency]] = None
|
||||
dependencies: list[Dependency] | None = None
|
||||
|
||||
@ -5,7 +5,7 @@ import logging
|
||||
import os
|
||||
import tempfile
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any
|
||||
|
||||
import chardet
|
||||
import docx
|
||||
@ -47,10 +47,10 @@ class DocumentExtractorNode(BaseNode):
|
||||
|
||||
_node_data: DocumentExtractorNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = DocumentExtractorNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -59,7 +59,7 @@ class DocumentExtractorNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -302,12 +302,12 @@ def _extract_text_from_yaml(file_content: bytes) -> str:
|
||||
encoding = "utf-8"
|
||||
|
||||
yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore"))
|
||||
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
|
||||
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
|
||||
except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e:
|
||||
# If decoding fails, try with utf-8 as last resort
|
||||
try:
|
||||
yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore"))
|
||||
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
|
||||
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
|
||||
except (UnicodeDecodeError, yaml.YAMLError):
|
||||
raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e
|
||||
|
||||
@ -428,9 +428,9 @@ def _download_file_content(file: File) -> bytes:
|
||||
raise FileDownloadError("Missing URL for remote file")
|
||||
response = ssrf_proxy.get(file.remote_url)
|
||||
response.raise_for_status()
|
||||
return cast(bytes, response.content)
|
||||
return response.content
|
||||
else:
|
||||
return cast(bytes, file_manager.download(file))
|
||||
return file_manager.download(file)
|
||||
except Exception as e:
|
||||
raise FileDownloadError(f"Error downloading file: {str(e)}") from e
|
||||
|
||||
@ -515,14 +515,14 @@ def _extract_text_from_excel(file_content: bytes) -> str:
|
||||
df.dropna(how="all", inplace=True)
|
||||
|
||||
# Combine multi-line text in each cell into a single line
|
||||
df = df.applymap(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x) # type: ignore
|
||||
df = df.map(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x)
|
||||
|
||||
# Combine multi-line text in column names into a single line
|
||||
df.columns = pd.Index([" ".join(str(col).splitlines()) for col in df.columns])
|
||||
|
||||
# Manually construct the Markdown table
|
||||
markdown_table += _construct_markdown_table(df) + "\n\n"
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
continue
|
||||
return markdown_table
|
||||
except Exception as e:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
@ -14,10 +14,10 @@ class EndNode(BaseNode):
|
||||
|
||||
_node_data: EndNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = EndNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -26,7 +26,7 @@ class EndNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
||||
@ -121,7 +121,7 @@ class EndStreamGeneratorRouter:
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
end_dependencies: dict[str, list[str]],
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Recursive fetch end dependencies
|
||||
:param current_node_id: current node id
|
||||
|
||||
@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EndStreamProcessor(StreamProcessor):
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool):
|
||||
super().__init__(graph, variable_pool)
|
||||
self.end_stream_param = graph.end_stream_param
|
||||
self.route_position = {}
|
||||
@ -76,7 +76,7 @@ class EndStreamProcessor(StreamProcessor):
|
||||
else:
|
||||
yield event
|
||||
|
||||
def reset(self) -> None:
|
||||
def reset(self):
|
||||
self.route_position = {}
|
||||
for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
|
||||
self.route_position[end_node_id] = 0
|
||||
|
||||
@ -30,6 +30,7 @@ class ModelInvokeCompletedEvent(BaseModel):
|
||||
text: str
|
||||
usage: LLMUsage
|
||||
finish_reason: str | None = None
|
||||
reasoning_content: str | None = None
|
||||
|
||||
|
||||
class RunRetryEvent(BaseModel):
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import mimetypes
|
||||
from collections.abc import Sequence
|
||||
from email.message import Message
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
@ -18,7 +18,7 @@ class HttpRequestNodeAuthorizationConfig(BaseModel):
|
||||
|
||||
class HttpRequestNodeAuthorization(BaseModel):
|
||||
type: Literal["no-auth", "api-key"]
|
||||
config: Optional[HttpRequestNodeAuthorizationConfig] = None
|
||||
config: HttpRequestNodeAuthorizationConfig | None = None
|
||||
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
@ -88,9 +88,9 @@ class HttpRequestNodeData(BaseNodeData):
|
||||
authorization: HttpRequestNodeAuthorization
|
||||
headers: str
|
||||
params: str
|
||||
body: Optional[HttpRequestNodeBody] = None
|
||||
timeout: Optional[HttpRequestNodeTimeout] = None
|
||||
ssl_verify: Optional[bool] = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
body: HttpRequestNodeBody | None = None
|
||||
timeout: HttpRequestNodeTimeout | None = None
|
||||
ssl_verify: bool | None = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
|
||||
|
||||
class Response:
|
||||
@ -183,7 +183,7 @@ class Response:
|
||||
return f"{(self.size / 1024 / 1024):.2f} MB"
|
||||
|
||||
@property
|
||||
def parsed_content_disposition(self) -> Optional[Message]:
|
||||
def parsed_content_disposition(self) -> Message | None:
|
||||
content_disposition = self.headers.get("content-disposition", "")
|
||||
if content_disposition:
|
||||
msg = Message()
|
||||
|
||||
@ -329,22 +329,16 @@ class Executor:
|
||||
"""
|
||||
do http request depending on api bundle
|
||||
"""
|
||||
if self.method not in {
|
||||
"get",
|
||||
"head",
|
||||
"post",
|
||||
"put",
|
||||
"delete",
|
||||
"patch",
|
||||
"options",
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"PATCH",
|
||||
"DELETE",
|
||||
"HEAD",
|
||||
"OPTIONS",
|
||||
}:
|
||||
_METHOD_MAP = {
|
||||
"get": ssrf_proxy.get,
|
||||
"head": ssrf_proxy.head,
|
||||
"post": ssrf_proxy.post,
|
||||
"put": ssrf_proxy.put,
|
||||
"delete": ssrf_proxy.delete,
|
||||
"patch": ssrf_proxy.patch,
|
||||
}
|
||||
method_lc = self.method.lower()
|
||||
if method_lc not in _METHOD_MAP:
|
||||
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
|
||||
|
||||
request_args = {
|
||||
@ -362,11 +356,11 @@ class Executor:
|
||||
}
|
||||
# request_args = {k: v for k, v in request_args.items() if v is not None}
|
||||
try:
|
||||
response = getattr(ssrf_proxy, self.method.lower())(**request_args)
|
||||
response: httpx.Response = _METHOD_MAP[method_lc](**request_args)
|
||||
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
|
||||
raise HttpRequestNodeError(str(e)) from e
|
||||
# FIXME: fix type ignore, this maybe httpx type issue
|
||||
return response # type: ignore
|
||||
return response
|
||||
|
||||
def invoke(self) -> Response:
|
||||
# assemble headers
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import mimetypes
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod
|
||||
@ -38,10 +38,10 @@ class HttpRequestNode(BaseNode):
|
||||
|
||||
_node_data: HttpRequestNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = HttpRequestNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -50,7 +50,7 @@ class HttpRequestNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -60,7 +60,7 @@ class HttpRequestNode(BaseNode):
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
|
||||
def get_default_config(cls, filters: dict[str, Any] | None = None):
|
||||
return {
|
||||
"type": "http-request",
|
||||
"config": {
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -20,7 +20,7 @@ class IfElseNodeData(BaseNodeData):
|
||||
logical_operator: Literal["and", "or"]
|
||||
conditions: list[Condition]
|
||||
|
||||
logical_operator: Optional[Literal["and", "or"]] = "and"
|
||||
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
|
||||
logical_operator: Literal["and", "or"] | None = "and"
|
||||
conditions: list[Condition] | None = Field(default=None, deprecated=True)
|
||||
|
||||
cases: Optional[list[Case]] = None
|
||||
cases: list[Case] | None = None
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
@ -19,10 +19,10 @@ class IfElseNode(BaseNode):
|
||||
|
||||
_node_data: IfElseNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = IfElseNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -31,7 +31,7 @@ class IfElseNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -83,7 +83,7 @@ class IfElseNode(BaseNode):
|
||||
else:
|
||||
# TODO: Update database then remove this
|
||||
# Fallback to old structure if cases are not defined
|
||||
input_conditions, group_result, final_result = _should_not_use_old_function(
|
||||
input_conditions, group_result, final_result = _should_not_use_old_function( # ty: ignore [deprecated]
|
||||
condition_processor=condition_processor,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=self._node_data.conditions or [],
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@ -17,7 +17,7 @@ class IterationNodeData(BaseIterationNodeData):
|
||||
Iteration Node Data.
|
||||
"""
|
||||
|
||||
parent_loop_id: Optional[str] = None # redundant field, not used currently
|
||||
parent_loop_id: str | None = None # redundant field, not used currently
|
||||
iterator_selector: list[str] # variable selector
|
||||
output_selector: list[str] # output selector
|
||||
is_parallel: bool = False # open the parallel mode or not
|
||||
@ -39,7 +39,7 @@ class IterationState(BaseIterationState):
|
||||
"""
|
||||
|
||||
outputs: list[Any] = Field(default_factory=list)
|
||||
current_output: Optional[Any] = None
|
||||
current_output: Any | None = None
|
||||
|
||||
class MetaData(BaseIterationState.MetaData):
|
||||
"""
|
||||
@ -48,7 +48,7 @@ class IterationState(BaseIterationState):
|
||||
|
||||
iterator_length: int
|
||||
|
||||
def get_last_output(self) -> Optional[Any]:
|
||||
def get_last_output(self) -> Any | None:
|
||||
"""
|
||||
Get last output.
|
||||
"""
|
||||
@ -56,7 +56,7 @@ class IterationState(BaseIterationState):
|
||||
return self.outputs[-1]
|
||||
return None
|
||||
|
||||
def get_current_output(self) -> Optional[Any]:
|
||||
def get_current_output(self) -> Any | None:
|
||||
"""
|
||||
Get current output.
|
||||
"""
|
||||
|
||||
@ -6,12 +6,12 @@ from collections.abc import Generator, Mapping, Sequence
|
||||
from concurrent.futures import Future, wait
|
||||
from datetime import datetime
|
||||
from queue import Empty, Queue
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
from configs import dify_config
|
||||
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
|
||||
from core.variables import IntegerVariable, NoneSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.workflow.entities.node_entities import (
|
||||
NodeRunResult,
|
||||
@ -67,10 +67,10 @@ class IterationNode(BaseNode):
|
||||
|
||||
_node_data: IterationNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = IterationNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -79,7 +79,7 @@ class IterationNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -89,7 +89,7 @@ class IterationNode(BaseNode):
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
def get_default_config(cls, filters: dict | None = None):
|
||||
return {
|
||||
"type": "iteration",
|
||||
"config": {
|
||||
@ -112,10 +112,10 @@ class IterationNode(BaseNode):
|
||||
if not variable:
|
||||
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
|
||||
|
||||
if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
|
||||
if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
|
||||
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
|
||||
|
||||
if isinstance(variable, NoneVariable) or len(variable.value) == 0:
|
||||
if isinstance(variable, NoneSegment) or len(variable.value) == 0:
|
||||
# Try our best to preserve the type informat.
|
||||
if isinstance(variable, ArraySegment):
|
||||
output = variable.model_copy(update={"value": []})
|
||||
@ -424,7 +424,7 @@ class IterationNode(BaseNode):
|
||||
graph_engine: "GraphEngine",
|
||||
iteration_graph: Graph,
|
||||
iter_run_map: dict[str, float],
|
||||
parallel_mode_run_id: Optional[str] = None,
|
||||
parallel_mode_run_id: str | None = None,
|
||||
) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
"""
|
||||
run single iteration
|
||||
@ -441,8 +441,8 @@ class IterationNode(BaseNode):
|
||||
iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}"
|
||||
next_index = int(current_index) + 1
|
||||
for event in rst:
|
||||
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
|
||||
event.in_iteration_id = self.node_id
|
||||
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: # ty: ignore [unresolved-attribute]
|
||||
event.in_iteration_id = self.node_id # ty: ignore [unresolved-attribute]
|
||||
|
||||
if (
|
||||
isinstance(event, BaseNodeEvent)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
@ -18,10 +18,10 @@ class IterationStartNode(BaseNode):
|
||||
|
||||
_node_data: IterationStartNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = IterationStartNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -30,7 +30,7 @@ class IterationStartNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -49,11 +49,11 @@ class MultipleRetrievalConfig(BaseModel):
|
||||
"""
|
||||
|
||||
top_k: int
|
||||
score_threshold: Optional[float] = None
|
||||
score_threshold: float | None = None
|
||||
reranking_mode: str = "reranking_model"
|
||||
reranking_enable: bool = True
|
||||
reranking_model: Optional[RerankingModelConfig] = None
|
||||
weights: Optional[WeightedScoreConfig] = None
|
||||
reranking_model: RerankingModelConfig | None = None
|
||||
weights: WeightedScoreConfig | None = None
|
||||
|
||||
|
||||
class SingleRetrievalConfig(BaseModel):
|
||||
@ -104,8 +104,8 @@ class MetadataFilteringCondition(BaseModel):
|
||||
Metadata Filtering Condition.
|
||||
"""
|
||||
|
||||
logical_operator: Optional[Literal["and", "or"]] = "and"
|
||||
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
|
||||
logical_operator: Literal["and", "or"] | None = "and"
|
||||
conditions: list[Condition] | None = Field(default=None, deprecated=True)
|
||||
|
||||
|
||||
class KnowledgeRetrievalNodeData(BaseNodeData):
|
||||
@ -117,11 +117,11 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
|
||||
query_variable_selector: list[str]
|
||||
dataset_ids: list[str]
|
||||
retrieval_mode: Literal["single", "multiple"]
|
||||
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
|
||||
single_retrieval_config: Optional[SingleRetrievalConfig] = None
|
||||
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
|
||||
metadata_model_config: Optional[ModelConfig] = None
|
||||
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
|
||||
multiple_retrieval_config: MultipleRetrievalConfig | None = None
|
||||
single_retrieval_config: SingleRetrievalConfig | None = None
|
||||
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
|
||||
metadata_model_config: ModelConfig | None = None
|
||||
metadata_filtering_conditions: MetadataFilteringCondition | None = None
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
|
||||
@property
|
||||
|
||||
@ -4,9 +4,9 @@ import re
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from sqlalchemy import Float, and_, func, or_, text
|
||||
from sqlalchemy import Float, and_, func, or_, select, text
|
||||
from sqlalchemy import cast as sqlalchemy_cast
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
@ -78,7 +78,7 @@ default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"top_k": 4,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
@ -101,11 +101,11 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph: "Graph",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
previous_node_id: Optional[str] = None,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
previous_node_id: str | None = None,
|
||||
thread_pool_id: str | None = None,
|
||||
*,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
) -> None:
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
@ -125,10 +125,10 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = KnowledgeRetrievalNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -137,7 +137,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -259,7 +259,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
)
|
||||
all_documents = []
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
|
||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||
# fetch model config
|
||||
if node_data.single_retrieval_config is None:
|
||||
raise ValueError("single_retrieval_config is required")
|
||||
@ -291,7 +291,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||
metadata_condition=metadata_condition,
|
||||
)
|
||||
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
|
||||
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
if node_data.multiple_retrieval_config is None:
|
||||
raise ValueError("multiple_retrieval_config is required")
|
||||
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
|
||||
@ -367,15 +367,12 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
|
||||
document = (
|
||||
db.session.query(Document)
|
||||
.where(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
.first()
|
||||
stmt = select(Document).where(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
document = db.session.scalar(stmt)
|
||||
if dataset and document:
|
||||
source = {
|
||||
"metadata": {
|
||||
@ -422,7 +419,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
|
||||
def _get_metadata_filter_condition(
|
||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
|
||||
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
|
||||
document_query = db.session.query(Document).where(
|
||||
Document.dataset_id.in_(dataset_ids),
|
||||
Document.indexing_status == "completed",
|
||||
@ -514,7 +511,8 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||
) -> list[dict[str, Any]]:
|
||||
# get all metadata field
|
||||
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
|
||||
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
|
||||
metadata_fields = db.session.scalars(stmt).all()
|
||||
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
|
||||
if node_data.metadata_model_config is None:
|
||||
raise ValueError("metadata_model_config is required")
|
||||
@ -573,12 +571,12 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
"condition": item.get("comparison_operator"),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return []
|
||||
return automatic_metadata_filters
|
||||
|
||||
def _process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
|
||||
):
|
||||
if value is None and condition not in ("empty", "not empty"):
|
||||
return
|
||||
|
||||
@ -1,36 +1,43 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
_Condition = Literal[
|
||||
|
||||
class FilterOperator(StrEnum):
|
||||
# string conditions
|
||||
"contains",
|
||||
"start with",
|
||||
"end with",
|
||||
"is",
|
||||
"in",
|
||||
"empty",
|
||||
"not contains",
|
||||
"is not",
|
||||
"not in",
|
||||
"not empty",
|
||||
CONTAINS = "contains"
|
||||
START_WITH = "start with"
|
||||
END_WITH = "end with"
|
||||
IS = "is"
|
||||
IN = "in"
|
||||
EMPTY = "empty"
|
||||
NOT_CONTAINS = "not contains"
|
||||
IS_NOT = "is not"
|
||||
NOT_IN = "not in"
|
||||
NOT_EMPTY = "not empty"
|
||||
# number conditions
|
||||
"=",
|
||||
"≠",
|
||||
"<",
|
||||
">",
|
||||
"≥",
|
||||
"≤",
|
||||
]
|
||||
EQUAL = "="
|
||||
NOT_EQUAL = "≠"
|
||||
LESS_THAN = "<"
|
||||
GREATER_THAN = ">"
|
||||
GREATER_THAN_OR_EQUAL = "≥"
|
||||
LESS_THAN_OR_EQUAL = "≤"
|
||||
|
||||
|
||||
class Order(StrEnum):
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
|
||||
class FilterCondition(BaseModel):
|
||||
key: str = ""
|
||||
comparison_operator: _Condition = "contains"
|
||||
value: str | Sequence[str] = ""
|
||||
comparison_operator: FilterOperator = FilterOperator.CONTAINS
|
||||
# the value is bool if the filter operator is comparing with
|
||||
# a boolean constant.
|
||||
value: str | Sequence[str] | bool = ""
|
||||
|
||||
|
||||
class FilterBy(BaseModel):
|
||||
@ -38,10 +45,10 @@ class FilterBy(BaseModel):
|
||||
conditions: Sequence[FilterCondition] = Field(default_factory=list)
|
||||
|
||||
|
||||
class OrderBy(BaseModel):
|
||||
class OrderByConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
key: str = ""
|
||||
value: Literal["asc", "desc"] = "asc"
|
||||
value: Order = Order.ASC
|
||||
|
||||
|
||||
class Limit(BaseModel):
|
||||
@ -57,6 +64,6 @@ class ExtractConfig(BaseModel):
|
||||
class ListOperatorNodeData(BaseNodeData):
|
||||
variable: Sequence[str] = Field(default_factory=list)
|
||||
filter_by: FilterBy
|
||||
order_by: OrderBy
|
||||
order_by: OrderByConfig
|
||||
limit: Limit
|
||||
extract_by: ExtractConfig = Field(default_factory=ExtractConfig)
|
||||
|
||||
@ -1,28 +1,50 @@
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Any, TypeAlias, TypeVar
|
||||
|
||||
from core.file import File
|
||||
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
|
||||
from .entities import ListOperatorNodeData
|
||||
from .entities import FilterOperator, ListOperatorNodeData, Order
|
||||
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
|
||||
|
||||
_SUPPORTED_TYPES_TUPLE = (
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayStringSegment,
|
||||
ArrayBooleanSegment,
|
||||
)
|
||||
_SUPPORTED_TYPES_ALIAS: TypeAlias = ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment | ArrayBooleanSegment
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
|
||||
"""Returns the negation of a given filter function. If the original filter
|
||||
returns `True` for a value, the negated filter will return `False`, and vice versa.
|
||||
"""
|
||||
|
||||
def wrapper(value: _T) -> bool:
|
||||
return not filter_(value)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ListOperatorNode(BaseNode):
|
||||
_node_type = NodeType.LIST_OPERATOR
|
||||
|
||||
_node_data: ListOperatorNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = ListOperatorNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -31,7 +53,7 @@ class ListOperatorNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -45,8 +67,8 @@ class ListOperatorNode(BaseNode):
|
||||
return "1"
|
||||
|
||||
def _run(self):
|
||||
inputs: dict[str, list] = {}
|
||||
process_data: dict[str, list] = {}
|
||||
inputs: dict[str, Sequence[object]] = {}
|
||||
process_data: dict[str, Sequence[object]] = {}
|
||||
outputs: dict[str, Any] = {}
|
||||
|
||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)
|
||||
@ -69,11 +91,8 @@ class ListOperatorNode(BaseNode):
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
)
|
||||
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
|
||||
error_message = (
|
||||
f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
|
||||
"or ArrayStringSegment"
|
||||
)
|
||||
if not isinstance(variable, _SUPPORTED_TYPES_TUPLE):
|
||||
error_message = f"Variable {self._node_data.variable} is not an array type, actual type: {type(variable)}"
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
|
||||
)
|
||||
@ -122,9 +141,7 @@ class ListOperatorNode(BaseNode):
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
def _apply_filter(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
filter_func: Callable[[Any], bool]
|
||||
result: list[Any] = []
|
||||
for condition in self._node_data.filter_by.conditions:
|
||||
@ -154,33 +171,35 @@ class ListOperatorNode(BaseNode):
|
||||
)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayBooleanSegment):
|
||||
if not isinstance(condition.value, bool):
|
||||
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
|
||||
filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
else:
|
||||
raise AssertionError("this statment should be unreachable.")
|
||||
return variable
|
||||
|
||||
def _apply_order(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
if isinstance(variable, ArrayStringSegment):
|
||||
result = _order_string(order=self._node_data.order_by.value, array=variable.value)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayNumberSegment):
|
||||
result = _order_number(order=self._node_data.order_by.value, array=variable.value)
|
||||
def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
|
||||
result = sorted(variable.value, reverse=self._node_data.order_by == Order.DESC)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayFileSegment):
|
||||
result = _order_file(
|
||||
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
|
||||
)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable")
|
||||
|
||||
return variable
|
||||
|
||||
def _apply_slice(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
result = variable.value[: self._node_data.limit.size]
|
||||
return variable.model_copy(update={"value": result})
|
||||
|
||||
def _extract_slice(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
|
||||
if value < 1:
|
||||
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
|
||||
@ -232,11 +251,11 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo
|
||||
case "empty":
|
||||
return lambda x: x == ""
|
||||
case "not contains":
|
||||
return lambda x: not _contains(value)(x)
|
||||
return _negation(_contains(value))
|
||||
case "is not":
|
||||
return lambda x: not _is(value)(x)
|
||||
return _negation(_is(value))
|
||||
case "not in":
|
||||
return lambda x: not _in(value)(x)
|
||||
return _negation(_in(value))
|
||||
case "not empty":
|
||||
return lambda x: x != ""
|
||||
case _:
|
||||
@ -248,7 +267,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab
|
||||
case "in":
|
||||
return _in(value)
|
||||
case "not in":
|
||||
return lambda x: not _in(value)(x)
|
||||
return _negation(_in(value))
|
||||
case _:
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
@ -271,6 +290,16 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Callable[[bool], bool]:
|
||||
match condition:
|
||||
case FilterOperator.IS:
|
||||
return _is(value)
|
||||
case FilterOperator.IS_NOT:
|
||||
return _negation(_is(value))
|
||||
case _:
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
|
||||
extract_func: Callable[[File], Any]
|
||||
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
|
||||
@ -298,7 +327,7 @@ def _endswith(value: str) -> Callable[[str], bool]:
|
||||
return lambda x: x.endswith(value)
|
||||
|
||||
|
||||
def _is(value: str) -> Callable[[str], bool]:
|
||||
def _is(value: _T) -> Callable[[_T], bool]:
|
||||
return lambda x: x == value
|
||||
|
||||
|
||||
@ -330,21 +359,13 @@ def _ge(value: int | float) -> Callable[[int | float], bool]:
|
||||
return lambda x: x >= value
|
||||
|
||||
|
||||
def _order_number(*, order: Literal["asc", "desc"], array: Sequence[int | float]):
|
||||
return sorted(array, key=lambda x: x, reverse=order == "desc")
|
||||
|
||||
|
||||
def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]):
|
||||
return sorted(array, key=lambda x: x, reverse=order == "desc")
|
||||
|
||||
|
||||
def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]):
|
||||
def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]):
|
||||
extract_func: Callable[[File], Any]
|
||||
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}:
|
||||
extract_func = _get_file_extract_string_func(key=order_by)
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
|
||||
elif order_by == "size":
|
||||
extract_func = _get_file_extract_number_func(key=order_by)
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
|
||||
else:
|
||||
raise InvalidKeyError(f"Invalid order key: {order_by}")
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
@ -18,7 +18,7 @@ class ModelConfig(BaseModel):
|
||||
|
||||
class ContextConfig(BaseModel):
|
||||
enabled: bool
|
||||
variable_selector: Optional[list[str]] = None
|
||||
variable_selector: list[str] | None = None
|
||||
|
||||
|
||||
class VisionConfigOptions(BaseModel):
|
||||
@ -51,23 +51,40 @@ class PromptConfig(BaseModel):
|
||||
|
||||
class LLMNodeChatModelMessage(ChatModelMessage):
|
||||
text: str = ""
|
||||
jinja2_text: Optional[str] = None
|
||||
jinja2_text: str | None = None
|
||||
|
||||
|
||||
class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
|
||||
jinja2_text: Optional[str] = None
|
||||
jinja2_text: str | None = None
|
||||
|
||||
|
||||
class LLMNodeData(BaseNodeData):
|
||||
model: ModelConfig
|
||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
|
||||
memory: Optional[MemoryConfig] = None
|
||||
memory: MemoryConfig | None = None
|
||||
context: ContextConfig
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
structured_output: Mapping[str, Any] | None = None
|
||||
# We used 'structured_output_enabled' in the past, but it's not a good name.
|
||||
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
|
||||
reasoning_format: Literal["separated", "tagged"] = Field(
|
||||
# Keep tagged as default for backward compatibility
|
||||
default="tagged",
|
||||
description=(
|
||||
"""
|
||||
Strategy for handling model reasoning output.
|
||||
|
||||
separated: Return clean text (without <think> tags) + reasoning_content field.
|
||||
Recommended for new workflows. Enables safe downstream parsing and
|
||||
workflow variable access: {{#node_id.reasoning_content#}}
|
||||
|
||||
tagged : Return original text (with <think> tags) + reasoning_content field.
|
||||
Maintains full backward compatibility while still providing reasoning_content
|
||||
for workflow automation. Frontend thinking panels work as before.
|
||||
"""
|
||||
),
|
||||
)
|
||||
|
||||
@field_validator("prompt_config", mode="before")
|
||||
@classmethod
|
||||
|
||||
@ -41,5 +41,5 @@ class FileTypeNotSupportError(LLMNodeError):
|
||||
|
||||
|
||||
class UnsupportedPromptContentTypeError(LLMNodeError):
|
||||
def __init__(self, *, type_name: str) -> None:
|
||||
def __init__(self, *, type_name: str):
|
||||
super().__init__(f"Prompt content type {type_name} is not supported.")
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional, cast
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
@ -86,8 +86,8 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
|
||||
|
||||
|
||||
def fetch_memory(
|
||||
variable_pool: VariablePool, app_id: str, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
|
||||
) -> Optional[TokenBufferMemory]:
|
||||
variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
|
||||
) -> TokenBufferMemory | None:
|
||||
if not node_data_memory:
|
||||
return None
|
||||
|
||||
@ -107,7 +107,7 @@ def fetch_memory(
|
||||
return memory
|
||||
|
||||
|
||||
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
|
||||
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
|
||||
@ -2,8 +2,9 @@ import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -59,7 +60,6 @@ from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
@ -96,6 +96,7 @@ from .file_saver import FileSaverImpl, LLMFileSaver
|
||||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -105,6 +106,9 @@ class LLMNode(BaseNode):
|
||||
|
||||
_node_data: LLMNodeData
|
||||
|
||||
# Compiled regex for extracting <think> blocks (with compatibility for attributes)
|
||||
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
|
||||
|
||||
# Instance attributes specific to LLMNode.
|
||||
# Output variable for file
|
||||
_file_outputs: list["File"]
|
||||
@ -118,11 +122,11 @@ class LLMNode(BaseNode):
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph: "Graph",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
previous_node_id: Optional[str] = None,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
previous_node_id: str | None = None,
|
||||
thread_pool_id: str | None = None,
|
||||
*,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
) -> None:
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
@ -142,10 +146,10 @@ class LLMNode(BaseNode):
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = LLMNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -154,7 +158,7 @@ class LLMNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -167,12 +171,13 @@ class LLMNode(BaseNode):
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
node_inputs: Optional[dict[str, Any]] = None
|
||||
def _run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
|
||||
node_inputs: dict[str, Any] | None = None
|
||||
process_data = None
|
||||
result_text = ""
|
||||
usage = LLMUsage.empty_usage()
|
||||
finish_reason = None
|
||||
reasoning_content = None
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
try:
|
||||
@ -262,6 +267,7 @@ class LLMNode(BaseNode):
|
||||
file_saver=self._llm_file_saver,
|
||||
file_outputs=self._file_outputs,
|
||||
node_id=self.node_id,
|
||||
reasoning_format=self._node_data.reasoning_format,
|
||||
)
|
||||
|
||||
structured_output: LLMStructuredOutput | None = None
|
||||
@ -270,9 +276,20 @@ class LLMNode(BaseNode):
|
||||
if isinstance(event, RunStreamChunkEvent):
|
||||
yield event
|
||||
elif isinstance(event, ModelInvokeCompletedEvent):
|
||||
# Raw text
|
||||
result_text = event.text
|
||||
usage = event.usage
|
||||
finish_reason = event.finish_reason
|
||||
reasoning_content = event.reasoning_content or ""
|
||||
|
||||
# For downstream nodes, determine clean text based on reasoning_format
|
||||
if self._node_data.reasoning_format == "tagged":
|
||||
# Keep <think> tags for backward compatibility
|
||||
clean_text = result_text
|
||||
else:
|
||||
# Extract clean text from <think> tags
|
||||
clean_text, _ = LLMNode._split_reasoning(result_text, self._node_data.reasoning_format)
|
||||
|
||||
# deduct quota
|
||||
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
break
|
||||
@ -290,7 +307,12 @@ class LLMNode(BaseNode):
|
||||
"model_name": model_config.model,
|
||||
}
|
||||
|
||||
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
|
||||
outputs = {
|
||||
"text": clean_text,
|
||||
"reasoning_content": reasoning_content,
|
||||
"usage": jsonable_encoder(usage),
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
if structured_output:
|
||||
outputs["structured_output"] = structured_output.structured_output
|
||||
if self._file_outputs is not None:
|
||||
@ -342,13 +364,14 @@ class LLMNode(BaseNode):
|
||||
node_data_model: ModelConfig,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
user_id: str,
|
||||
structured_output_enabled: bool,
|
||||
structured_output: Optional[Mapping[str, Any]] = None,
|
||||
structured_output: Mapping[str, Any] | None = None,
|
||||
file_saver: LLMFileSaver,
|
||||
file_outputs: list["File"],
|
||||
node_id: str,
|
||||
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
||||
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
|
||||
model_schema = model_instance.model_type_instance.get_model_schema(
|
||||
node_data_model.name, model_instance.credentials
|
||||
@ -385,6 +408,7 @@ class LLMNode(BaseNode):
|
||||
file_saver=file_saver,
|
||||
file_outputs=file_outputs,
|
||||
node_id=node_id,
|
||||
reasoning_format=reasoning_format,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -394,6 +418,7 @@ class LLMNode(BaseNode):
|
||||
file_saver: LLMFileSaver,
|
||||
file_outputs: list["File"],
|
||||
node_id: str,
|
||||
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
||||
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
|
||||
# For blocking mode
|
||||
if isinstance(invoke_result, LLMResult):
|
||||
@ -401,6 +426,7 @@ class LLMNode(BaseNode):
|
||||
invoke_result=invoke_result,
|
||||
saver=file_saver,
|
||||
file_outputs=file_outputs,
|
||||
reasoning_format=reasoning_format,
|
||||
)
|
||||
yield event
|
||||
return
|
||||
@ -441,13 +467,66 @@ class LLMNode(BaseNode):
|
||||
except OutputParserError as e:
|
||||
raise LLMNodeError(f"Failed to parse structured output: {e}")
|
||||
|
||||
yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
|
||||
# Extract reasoning content from <think> tags in the main text
|
||||
full_text = full_text_buffer.getvalue()
|
||||
|
||||
if reasoning_format == "tagged":
|
||||
# Keep <think> tags in text for backward compatibility
|
||||
clean_text = full_text
|
||||
reasoning_content = ""
|
||||
else:
|
||||
# Extract clean text and reasoning from <think> tags
|
||||
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
|
||||
|
||||
yield ModelInvokeCompletedEvent(
|
||||
# Use clean_text for separated mode, full_text for tagged mode
|
||||
text=clean_text if reasoning_format == "separated" else full_text,
|
||||
usage=usage,
|
||||
finish_reason=finish_reason,
|
||||
# Reasoning content for workflow variables and downstream nodes
|
||||
reasoning_content=reasoning_content,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _image_file_to_markdown(file: "File", /):
|
||||
text_chunk = f"})"
|
||||
return text_chunk
|
||||
|
||||
@classmethod
|
||||
def _split_reasoning(
|
||||
cls, text: str, reasoning_format: Literal["separated", "tagged"] = "tagged"
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Split reasoning content from text based on reasoning_format strategy.
|
||||
|
||||
Args:
|
||||
text: Full text that may contain <think> blocks
|
||||
reasoning_format: Strategy for handling reasoning content
|
||||
- "separated": Remove <think> tags and return clean text + reasoning_content field
|
||||
- "tagged": Keep <think> tags in text, return empty reasoning_content
|
||||
|
||||
Returns:
|
||||
tuple of (clean_text, reasoning_content)
|
||||
"""
|
||||
|
||||
if reasoning_format == "tagged":
|
||||
return text, ""
|
||||
|
||||
# Find all <think>...</think> blocks (case-insensitive)
|
||||
matches = cls._THINK_PATTERN.findall(text)
|
||||
|
||||
# Extract reasoning content from all <think> blocks
|
||||
reasoning_content = "\n".join(match.strip() for match in matches) if matches else ""
|
||||
|
||||
# Remove all <think>...</think> blocks from original text
|
||||
clean_text = cls._THINK_PATTERN.sub("", text)
|
||||
|
||||
# Clean up extra whitespace
|
||||
clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
|
||||
|
||||
# Separated mode: always return clean text and reasoning_content
|
||||
return clean_text, reasoning_content or ""
|
||||
|
||||
def _transform_chat_messages(
|
||||
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
|
||||
) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
|
||||
@ -640,7 +719,7 @@ class LLMNode(BaseNode):
|
||||
variable_pool: VariablePool,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
tenant_id: str,
|
||||
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
|
||||
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
|
||||
if isinstance(prompt_template, list):
|
||||
@ -748,7 +827,7 @@ class LLMNode(BaseNode):
|
||||
and isinstance(prompt_messages[-1], UserPromptMessage)
|
||||
and isinstance(prompt_messages[-1].content, list)
|
||||
):
|
||||
prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts)
|
||||
prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
||||
|
||||
@ -883,7 +962,7 @@ class LLMNode(BaseNode):
|
||||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
def get_default_config(cls, filters: dict | None = None):
|
||||
return {
|
||||
"type": "llm",
|
||||
"config": {
|
||||
@ -911,7 +990,7 @@ class LLMNode(BaseNode):
|
||||
def handle_list_messages(
|
||||
*,
|
||||
messages: Sequence[LLMNodeChatModelMessage],
|
||||
context: Optional[str],
|
||||
context: str | None,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
variable_pool: VariablePool,
|
||||
vision_detail_config: ImagePromptMessageContent.DETAIL,
|
||||
@ -975,6 +1054,7 @@ class LLMNode(BaseNode):
|
||||
invoke_result: LLMResult,
|
||||
saver: LLMFileSaver,
|
||||
file_outputs: list["File"],
|
||||
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
||||
) -> ModelInvokeCompletedEvent:
|
||||
buffer = io.StringIO()
|
||||
for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
|
||||
@ -984,10 +1064,24 @@ class LLMNode(BaseNode):
|
||||
):
|
||||
buffer.write(text_part)
|
||||
|
||||
# Extract reasoning content from <think> tags in the main text
|
||||
full_text = buffer.getvalue()
|
||||
|
||||
if reasoning_format == "tagged":
|
||||
# Keep <think> tags in text for backward compatibility
|
||||
clean_text = full_text
|
||||
reasoning_content = ""
|
||||
else:
|
||||
# Extract clean text and reasoning from <think> tags
|
||||
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
|
||||
|
||||
return ModelInvokeCompletedEvent(
|
||||
text=buffer.getvalue(),
|
||||
# Use clean_text for separated mode, full_text for tagged mode
|
||||
text=clean_text if reasoning_format == "separated" else full_text,
|
||||
usage=invoke_result.usage,
|
||||
finish_reason=None,
|
||||
# Reasoning content for workflow variables and downstream nodes
|
||||
reasoning_content=reasoning_content,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -1156,7 +1250,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
|
||||
def _combine_message_content_with_role(
|
||||
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
|
||||
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
|
||||
):
|
||||
match role:
|
||||
case PromptMessageRole.USER:
|
||||
@ -1165,7 +1259,8 @@ def _combine_message_content_with_role(
|
||||
return AssistantPromptMessage(content=contents)
|
||||
case PromptMessageRole.SYSTEM:
|
||||
return SystemPromptMessage(content=contents)
|
||||
raise NotImplementedError(f"Role {role} is not supported")
|
||||
case _:
|
||||
raise NotImplementedError(f"Role {role} is not supported")
|
||||
|
||||
|
||||
def _render_jinja2_message(
|
||||
@ -1261,7 +1356,7 @@ def _handle_memory_completion_mode(
|
||||
def _handle_completion_template(
|
||||
*,
|
||||
template: LLMNodeCompletionModelPromptTemplate,
|
||||
context: Optional[str],
|
||||
context: str | None,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
variable_pool: VariablePool,
|
||||
) -> Sequence[PromptMessage]:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Annotated, Any, Literal, Optional
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, Field
|
||||
|
||||
@ -12,9 +12,11 @@ _VALID_VAR_TYPE = frozenset(
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
]
|
||||
)
|
||||
|
||||
@ -33,7 +35,7 @@ class LoopVariableData(BaseModel):
|
||||
label: str
|
||||
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
|
||||
value_type: Literal["variable", "constant"]
|
||||
value: Optional[Any | list[str]] = None
|
||||
value: Any | list[str] | None = None
|
||||
|
||||
|
||||
class LoopNodeData(BaseLoopNodeData):
|
||||
@ -44,8 +46,8 @@ class LoopNodeData(BaseLoopNodeData):
|
||||
loop_count: int # Maximum number of loops
|
||||
break_conditions: list[Condition] # Conditions to break the loop
|
||||
logical_operator: Literal["and", "or"]
|
||||
loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list[LoopVariableData])
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
|
||||
|
||||
class LoopStartNodeData(BaseNodeData):
|
||||
@ -70,7 +72,7 @@ class LoopState(BaseLoopState):
|
||||
"""
|
||||
|
||||
outputs: list[Any] = Field(default_factory=list)
|
||||
current_output: Optional[Any] = None
|
||||
current_output: Any | None = None
|
||||
|
||||
class MetaData(BaseLoopState.MetaData):
|
||||
"""
|
||||
@ -79,7 +81,7 @@ class LoopState(BaseLoopState):
|
||||
|
||||
loop_length: int
|
||||
|
||||
def get_last_output(self) -> Optional[Any]:
|
||||
def get_last_output(self) -> Any | None:
|
||||
"""
|
||||
Get last output.
|
||||
"""
|
||||
@ -87,7 +89,7 @@ class LoopState(BaseLoopState):
|
||||
return self.outputs[-1]
|
||||
return None
|
||||
|
||||
def get_current_output(self) -> Optional[Any]:
|
||||
def get_current_output(self) -> Any | None:
|
||||
"""
|
||||
Get current output.
|
||||
"""
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
@ -18,10 +18,10 @@ class LoopEndNode(BaseNode):
|
||||
|
||||
_node_data: LoopEndNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = LoopEndNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -30,7 +30,7 @@ class LoopEndNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
import time
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.variables import (
|
||||
@ -54,10 +54,10 @@ class LoopNode(BaseNode):
|
||||
|
||||
_node_data: LoopNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = LoopNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -66,7 +66,7 @@ class LoopNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -289,6 +289,8 @@ class LoopNode(BaseNode):
|
||||
Returns:
|
||||
dict: {'check_break_result': bool}
|
||||
"""
|
||||
condition_selectors = self._extract_selectors_from_conditions(break_conditions)
|
||||
extended_selectors = {**loop_variable_selectors, **condition_selectors}
|
||||
# Run workflow
|
||||
rst = graph_engine.run()
|
||||
current_index_variable = variable_pool.get([self.node_id, "index"])
|
||||
@ -299,8 +301,8 @@ class LoopNode(BaseNode):
|
||||
check_break_result = False
|
||||
|
||||
for event in rst:
|
||||
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id:
|
||||
event.in_loop_id = self.node_id
|
||||
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id: # ty: ignore [unresolved-attribute]
|
||||
event.in_loop_id = self.node_id # ty: ignore [unresolved-attribute]
|
||||
|
||||
if (
|
||||
isinstance(event, BaseNodeEvent)
|
||||
@ -314,31 +316,30 @@ class LoopNode(BaseNode):
|
||||
and event.node_type == NodeType.LOOP_END
|
||||
and not isinstance(event, NodeRunStreamChunkEvent)
|
||||
):
|
||||
# Check if variables in break conditions exist and process conditions
|
||||
# Allow loop internal variables to be used in break conditions
|
||||
available_conditions = []
|
||||
for condition in break_conditions:
|
||||
variable = self.graph_runtime_state.variable_pool.get(condition.variable_selector)
|
||||
if variable:
|
||||
available_conditions.append(condition)
|
||||
|
||||
# Process conditions if at least one variable is available
|
||||
if available_conditions:
|
||||
input_conditions, group_result, check_break_result = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=available_conditions,
|
||||
operator=logical_operator,
|
||||
)
|
||||
if check_break_result:
|
||||
break
|
||||
else:
|
||||
check_break_result = True
|
||||
check_break_result = True
|
||||
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
|
||||
break
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
|
||||
|
||||
# Check if all variables in break conditions exist
|
||||
exists_variable = False
|
||||
for condition in break_conditions:
|
||||
if not self.graph_runtime_state.variable_pool.get(condition.variable_selector):
|
||||
exists_variable = False
|
||||
break
|
||||
else:
|
||||
exists_variable = True
|
||||
if exists_variable:
|
||||
input_conditions, group_result, check_break_result = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=break_conditions,
|
||||
operator=logical_operator,
|
||||
)
|
||||
if check_break_result:
|
||||
break
|
||||
|
||||
elif isinstance(event, BaseGraphEvent):
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
# Loop run failed
|
||||
@ -400,21 +401,21 @@ class LoopNode(BaseNode):
|
||||
else:
|
||||
yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index)
|
||||
|
||||
# Remove all nodes outputs from variable pool
|
||||
for node_id in loop_graph.node_ids:
|
||||
variable_pool.remove([node_id])
|
||||
|
||||
_outputs = {}
|
||||
for loop_variable_key, loop_variable_selector in loop_variable_selectors.items():
|
||||
_outputs: dict[str, Segment | int | None] = {}
|
||||
for loop_variable_key, loop_variable_selector in extended_selectors.items():
|
||||
_loop_variable_segment = variable_pool.get(loop_variable_selector)
|
||||
if _loop_variable_segment:
|
||||
_outputs[loop_variable_key] = _loop_variable_segment.value
|
||||
_outputs[loop_variable_key] = _loop_variable_segment
|
||||
else:
|
||||
_outputs[loop_variable_key] = None
|
||||
|
||||
_outputs["loop_round"] = current_index + 1
|
||||
self._node_data.outputs = _outputs
|
||||
|
||||
# Remove all nodes outputs from variable pool
|
||||
for node_id in loop_graph.node_ids:
|
||||
variable_pool.remove([node_id])
|
||||
|
||||
if check_break_result:
|
||||
return {"check_break_result": True}
|
||||
|
||||
@ -433,6 +434,13 @@ class LoopNode(BaseNode):
|
||||
|
||||
return {"check_break_result": False}
|
||||
|
||||
def _extract_selectors_from_conditions(self, conditions: list) -> dict[str, list[str]]:
|
||||
return {
|
||||
condition.variable_selector[1]: condition.variable_selector
|
||||
for condition in conditions
|
||||
if condition.variable_selector and len(condition.variable_selector) >= 2
|
||||
}
|
||||
|
||||
def _handle_event_metadata(
|
||||
self,
|
||||
*,
|
||||
@ -522,21 +530,33 @@ class LoopNode(BaseNode):
|
||||
return variable_mapping
|
||||
|
||||
@staticmethod
|
||||
def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment:
|
||||
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
|
||||
"""Get the appropriate segment type for a constant value."""
|
||||
if var_type in ["array[string]", "array[number]", "array[object]"]:
|
||||
if value and isinstance(value, str):
|
||||
value = json.loads(value)
|
||||
# TODO: Refactor for maintainability:
|
||||
# 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py)
|
||||
# 2. Consider moving this method to LoopVariableData class for better encapsulation
|
||||
if not var_type.is_array_type() or var_type == SegmentType.ARRAY_BOOLEAN:
|
||||
value = original_value
|
||||
elif var_type in [
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
]:
|
||||
if original_value and isinstance(original_value, str):
|
||||
value = json.loads(original_value)
|
||||
else:
|
||||
logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type)
|
||||
value = []
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
try:
|
||||
return build_segment_with_type(var_type, value)
|
||||
return build_segment_with_type(var_type, value=value)
|
||||
except TypeMismatchError as type_exc:
|
||||
# Attempt to parse the value as a JSON-encoded string, if applicable.
|
||||
if not isinstance(value, str):
|
||||
if not isinstance(original_value, str):
|
||||
raise
|
||||
try:
|
||||
value = json.loads(value)
|
||||
value = json.loads(original_value)
|
||||
except ValueError:
|
||||
raise type_exc
|
||||
return build_segment_with_type(var_type, value)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
@ -18,10 +18,10 @@ class LoopStartNode(BaseNode):
|
||||
|
||||
_node_data: LoopStartNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = LoopStartNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -30,7 +30,7 @@ class LoopStartNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
||||
@ -1,14 +1,46 @@
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
BeforeValidator,
|
||||
Field,
|
||||
field_validator,
|
||||
)
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.llm import ModelConfig, VisionConfig
|
||||
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
|
||||
|
||||
_OLD_BOOL_TYPE_NAME = "bool"
|
||||
_OLD_SELECT_TYPE_NAME = "select"
|
||||
|
||||
_VALID_PARAMETER_TYPES = frozenset(
|
||||
[
|
||||
SegmentType.STRING, # "string",
|
||||
SegmentType.NUMBER, # "number",
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
_OLD_BOOL_TYPE_NAME, # old boolean type used by Parameter Extractor node
|
||||
_OLD_SELECT_TYPE_NAME, # string type with enumeration choices.
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class _ParameterConfigError(Exception):
|
||||
pass
|
||||
def _validate_type(parameter_type: str) -> SegmentType:
|
||||
if not isinstance(parameter_type, str):
|
||||
raise TypeError(f"type should be str, got {type(parameter_type)}, value={parameter_type}")
|
||||
if parameter_type not in _VALID_PARAMETER_TYPES:
|
||||
raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.")
|
||||
|
||||
if parameter_type == _OLD_BOOL_TYPE_NAME:
|
||||
return SegmentType.BOOLEAN
|
||||
elif parameter_type == _OLD_SELECT_TYPE_NAME:
|
||||
return SegmentType.STRING
|
||||
return SegmentType(parameter_type)
|
||||
|
||||
|
||||
class ParameterConfig(BaseModel):
|
||||
@ -17,8 +49,8 @@ class ParameterConfig(BaseModel):
|
||||
"""
|
||||
|
||||
name: str
|
||||
type: Literal["string", "number", "bool", "select", "array[string]", "array[number]", "array[object]"]
|
||||
options: Optional[list[str]] = None
|
||||
type: Annotated[SegmentType, BeforeValidator(_validate_type)]
|
||||
options: list[str] | None = None
|
||||
description: str
|
||||
required: bool
|
||||
|
||||
@ -32,17 +64,20 @@ class ParameterConfig(BaseModel):
|
||||
return str(value)
|
||||
|
||||
def is_array_type(self) -> bool:
|
||||
return self.type in ("array[string]", "array[number]", "array[object]")
|
||||
return self.type.is_array_type()
|
||||
|
||||
def element_type(self) -> Literal["string", "number", "object"]:
|
||||
if self.type == "array[number]":
|
||||
return "number"
|
||||
elif self.type == "array[string]":
|
||||
return "string"
|
||||
elif self.type == "array[object]":
|
||||
return "object"
|
||||
else:
|
||||
raise _ParameterConfigError(f"{self.type} is not array type.")
|
||||
def element_type(self) -> SegmentType:
|
||||
"""Return the element type of the parameter.
|
||||
|
||||
Raises a ValueError if the parameter's type is not an array type.
|
||||
"""
|
||||
element_type = self.type.element_type()
|
||||
# At this point, self.type is guaranteed to be one of `ARRAY_STRING`,
|
||||
# `ARRAY_NUMBER`, `ARRAY_OBJECT`, or `ARRAY_BOOLEAN`.
|
||||
#
|
||||
# See: _VALID_PARAMETER_TYPES for reference.
|
||||
assert element_type is not None, f"the element type should not be None, {self.type=}"
|
||||
return element_type
|
||||
|
||||
|
||||
class ParameterExtractorNodeData(BaseNodeData):
|
||||
@ -53,8 +88,8 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
model: ModelConfig
|
||||
query: list[str]
|
||||
parameters: list[ParameterConfig]
|
||||
instruction: Optional[str] = None
|
||||
memory: Optional[MemoryConfig] = None
|
||||
instruction: str | None = None
|
||||
memory: MemoryConfig | None = None
|
||||
reasoning_mode: Literal["function_call", "prompt"]
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
|
||||
@ -63,7 +98,7 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
def set_reasoning_mode(cls, v) -> str:
|
||||
return v or "function_call"
|
||||
|
||||
def get_parameter_json_schema(self) -> dict:
|
||||
def get_parameter_json_schema(self):
|
||||
"""
|
||||
Get parameter json schema.
|
||||
|
||||
@ -74,16 +109,18 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
for parameter in self.parameters:
|
||||
parameter_schema: dict[str, Any] = {"description": parameter.description}
|
||||
|
||||
if parameter.type in {"string", "select"}:
|
||||
if parameter.type == SegmentType.STRING:
|
||||
parameter_schema["type"] = "string"
|
||||
elif parameter.type.startswith("array"):
|
||||
elif parameter.type.is_array_type():
|
||||
parameter_schema["type"] = "array"
|
||||
nested_type = parameter.type[6:-1]
|
||||
parameter_schema["items"] = {"type": nested_type}
|
||||
element_type = parameter.type.element_type()
|
||||
if element_type is None:
|
||||
raise AssertionError("element type should not be None.")
|
||||
parameter_schema["items"] = {"type": element_type.value}
|
||||
else:
|
||||
parameter_schema["type"] = parameter.type
|
||||
|
||||
if parameter.type == "select":
|
||||
if parameter.options:
|
||||
parameter_schema["enum"] = parameter.options
|
||||
|
||||
parameters["properties"][parameter.name] = parameter_schema
|
||||
|
||||
@ -1,3 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
from core.variables.types import SegmentType
|
||||
|
||||
|
||||
class ParameterExtractorNodeError(ValueError):
|
||||
"""Base error for ParameterExtractorNode."""
|
||||
|
||||
@ -48,3 +53,23 @@ class InvalidArrayValueError(ParameterExtractorNodeError):
|
||||
|
||||
class InvalidModelModeError(ParameterExtractorNodeError):
|
||||
"""Raised when the model mode is invalid."""
|
||||
|
||||
|
||||
class InvalidValueTypeError(ParameterExtractorNodeError):
|
||||
def __init__(
|
||||
self,
|
||||
/,
|
||||
parameter_name: str,
|
||||
expected_type: SegmentType,
|
||||
actual_type: SegmentType | None,
|
||||
value: Any,
|
||||
):
|
||||
message = (
|
||||
f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, "
|
||||
f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}"
|
||||
)
|
||||
super().__init__(message)
|
||||
self.parameter_name = parameter_name
|
||||
self.expected_type = expected_type
|
||||
self.actual_type = actual_type
|
||||
self.value = value
|
||||
|
||||
@ -3,7 +3,7 @@ import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file import File
|
||||
@ -26,7 +26,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.variables.types import SegmentType
|
||||
from core.variables.types import ArrayValidation, SegmentType
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
@ -39,22 +39,20 @@ from factories.variable_factory import build_segment_with_type
|
||||
|
||||
from .entities import ParameterExtractorNodeData
|
||||
from .exc import (
|
||||
InvalidArrayValueError,
|
||||
InvalidBoolValueError,
|
||||
InvalidInvokeResultError,
|
||||
InvalidModelModeError,
|
||||
InvalidModelTypeError,
|
||||
InvalidNumberOfParametersError,
|
||||
InvalidNumberValueError,
|
||||
InvalidSelectValueError,
|
||||
InvalidStringValueError,
|
||||
InvalidTextContentTypeError,
|
||||
InvalidValueTypeError,
|
||||
ModelSchemaNotFoundError,
|
||||
ParameterExtractorNodeError,
|
||||
RequiredParameterMissingError,
|
||||
)
|
||||
from .prompts import (
|
||||
CHAT_EXAMPLE,
|
||||
CHAT_GENERATE_JSON_PROMPT,
|
||||
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
|
||||
COMPLETION_GENERATE_JSON_PROMPT,
|
||||
FUNCTION_CALLING_EXTRACTOR_EXAMPLE,
|
||||
@ -97,10 +95,10 @@ class ParameterExtractorNode(BaseNode):
|
||||
|
||||
_node_data: ParameterExtractorNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = ParameterExtractorNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -109,7 +107,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -118,11 +116,11 @@ class ParameterExtractorNode(BaseNode):
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
_model_instance: Optional[ModelInstance] = None
|
||||
_model_config: Optional[ModelConfigWithCredentialsEntity] = None
|
||||
_model_instance: ModelInstance | None = None
|
||||
_model_config: ModelConfigWithCredentialsEntity | None = None
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
def get_default_config(cls, filters: dict | None = None):
|
||||
return {
|
||||
"model": {
|
||||
"prompt_templates": {
|
||||
@ -142,7 +140,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
node_data = cast(ParameterExtractorNodeData, self._node_data)
|
||||
node_data = self._node_data
|
||||
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
|
||||
query = variable.text if variable else ""
|
||||
|
||||
@ -297,7 +295,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool],
|
||||
stop: list[str],
|
||||
) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
|
||||
) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]:
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=node_data_model.completion_params,
|
||||
@ -332,9 +330,9 @@ class ParameterExtractorNode(BaseNode):
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
memory: TokenBufferMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
|
||||
"""
|
||||
Generate function call prompt.
|
||||
@ -414,9 +412,9 @@ class ParameterExtractorNode(BaseNode):
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
memory: TokenBufferMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Generate prompt engineering prompt.
|
||||
@ -452,9 +450,9 @@ class ParameterExtractorNode(BaseNode):
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
memory: TokenBufferMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Generate completion prompt.
|
||||
@ -486,9 +484,9 @@ class ParameterExtractorNode(BaseNode):
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
memory: TokenBufferMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Generate chat prompt.
|
||||
@ -548,10 +546,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
|
||||
"""
|
||||
Validate result.
|
||||
"""
|
||||
def _validate_result(self, data: ParameterExtractorNodeData, result: dict):
|
||||
if len(data.parameters) != len(result):
|
||||
raise InvalidNumberOfParametersError("Invalid number of parameters")
|
||||
|
||||
@ -559,105 +554,110 @@ class ParameterExtractorNode(BaseNode):
|
||||
if parameter.required and parameter.name not in result:
|
||||
raise RequiredParameterMissingError(f"Parameter {parameter.name} is required")
|
||||
|
||||
if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options:
|
||||
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float):
|
||||
raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool):
|
||||
raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == "string" and not isinstance(result.get(parameter.name), str):
|
||||
raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type.startswith("array"):
|
||||
parameters = result.get(parameter.name)
|
||||
if not isinstance(parameters, list):
|
||||
raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}")
|
||||
nested_type = parameter.type[6:-1]
|
||||
for item in parameters:
|
||||
if nested_type == "number" and not isinstance(item, int | float):
|
||||
raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
|
||||
if nested_type == "string" and not isinstance(item, str):
|
||||
raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
|
||||
if nested_type == "object" and not isinstance(item, dict):
|
||||
raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
|
||||
param_value = result.get(parameter.name)
|
||||
if not parameter.type.is_valid(param_value, array_validation=ArrayValidation.ALL):
|
||||
inferred_type = SegmentType.infer_segment_type(param_value)
|
||||
raise InvalidValueTypeError(
|
||||
parameter_name=parameter.name,
|
||||
expected_type=parameter.type,
|
||||
actual_type=inferred_type,
|
||||
value=param_value,
|
||||
)
|
||||
if parameter.type == SegmentType.STRING and parameter.options:
|
||||
if param_value not in parameter.options:
|
||||
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
|
||||
return result
|
||||
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
|
||||
@staticmethod
|
||||
def _transform_number(value: int | float | str | bool) -> int | float | None:
|
||||
"""
|
||||
Attempts to transform the input into an integer or float.
|
||||
|
||||
Returns:
|
||||
int or float: The transformed number if the conversion is successful.
|
||||
None: If the transformation fails.
|
||||
|
||||
Note:
|
||||
Boolean values `True` and `False` are converted to integers `1` and `0`, respectively.
|
||||
This behavior ensures compatibility with existing workflows that may use boolean types as integers.
|
||||
"""
|
||||
if isinstance(value, bool):
|
||||
return int(value)
|
||||
elif isinstance(value, (int, float)):
|
||||
return value
|
||||
elif not isinstance(value, str):
|
||||
return None
|
||||
if "." in value:
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return None
|
||||
else:
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: dict):
|
||||
"""
|
||||
Transform result into standard format.
|
||||
"""
|
||||
transformed_result = {}
|
||||
transformed_result: dict[str, Any] = {}
|
||||
for parameter in data.parameters:
|
||||
if parameter.name in result:
|
||||
param_value = result[parameter.name]
|
||||
# transform value
|
||||
if parameter.type == "number":
|
||||
if isinstance(result[parameter.name], int | float):
|
||||
transformed_result[parameter.name] = result[parameter.name]
|
||||
elif isinstance(result[parameter.name], str):
|
||||
try:
|
||||
if "." in result[parameter.name]:
|
||||
result[parameter.name] = float(result[parameter.name])
|
||||
else:
|
||||
result[parameter.name] = int(result[parameter.name])
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
# TODO: bool is not supported in the current version
|
||||
# elif parameter.type == 'bool':
|
||||
# if isinstance(result[parameter.name], bool):
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
# elif isinstance(result[parameter.name], str):
|
||||
# if result[parameter.name].lower() in ['true', 'false']:
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true')
|
||||
# elif isinstance(result[parameter.name], int):
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
elif parameter.type in {"string", "select"}:
|
||||
if isinstance(result[parameter.name], str):
|
||||
transformed_result[parameter.name] = result[parameter.name]
|
||||
if parameter.type == SegmentType.NUMBER:
|
||||
transformed = self._transform_number(param_value)
|
||||
if transformed is not None:
|
||||
transformed_result[parameter.name] = transformed
|
||||
elif parameter.type == SegmentType.BOOLEAN:
|
||||
if isinstance(result[parameter.name], (bool, int)):
|
||||
transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
# elif isinstance(result[parameter.name], str):
|
||||
# if result[parameter.name].lower() in ["true", "false"]:
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == "true")
|
||||
elif parameter.type == SegmentType.STRING:
|
||||
if isinstance(param_value, str):
|
||||
transformed_result[parameter.name] = param_value
|
||||
elif parameter.is_array_type():
|
||||
if isinstance(result[parameter.name], list):
|
||||
if isinstance(param_value, list):
|
||||
nested_type = parameter.element_type()
|
||||
assert nested_type is not None
|
||||
segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[])
|
||||
transformed_result[parameter.name] = segment_value
|
||||
for item in result[parameter.name]:
|
||||
if nested_type == "number":
|
||||
if isinstance(item, int | float):
|
||||
segment_value.value.append(item)
|
||||
elif isinstance(item, str):
|
||||
try:
|
||||
if "." in item:
|
||||
segment_value.value.append(float(item))
|
||||
else:
|
||||
segment_value.value.append(int(item))
|
||||
except ValueError:
|
||||
pass
|
||||
elif nested_type == "string":
|
||||
for item in param_value:
|
||||
if nested_type == SegmentType.NUMBER:
|
||||
transformed = self._transform_number(item)
|
||||
if transformed is not None:
|
||||
segment_value.value.append(transformed)
|
||||
elif nested_type == SegmentType.STRING:
|
||||
if isinstance(item, str):
|
||||
segment_value.value.append(item)
|
||||
elif nested_type == "object":
|
||||
elif nested_type == SegmentType.OBJECT:
|
||||
if isinstance(item, dict):
|
||||
segment_value.value.append(item)
|
||||
elif nested_type == SegmentType.BOOLEAN:
|
||||
if isinstance(item, bool):
|
||||
segment_value.value.append(item)
|
||||
|
||||
if parameter.name not in transformed_result:
|
||||
if parameter.type == "number":
|
||||
transformed_result[parameter.name] = 0
|
||||
elif parameter.type == "bool":
|
||||
transformed_result[parameter.name] = False
|
||||
elif parameter.type in {"string", "select"}:
|
||||
transformed_result[parameter.name] = ""
|
||||
elif parameter.type.startswith("array"):
|
||||
if parameter.type.is_array_type():
|
||||
transformed_result[parameter.name] = build_segment_with_type(
|
||||
segment_type=SegmentType(parameter.type), value=[]
|
||||
)
|
||||
elif parameter.type in (SegmentType.STRING, SegmentType.SECRET):
|
||||
transformed_result[parameter.name] = ""
|
||||
elif parameter.type == SegmentType.NUMBER:
|
||||
transformed_result[parameter.name] = 0
|
||||
elif parameter.type == SegmentType.BOOLEAN:
|
||||
transformed_result[parameter.name] = False
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
|
||||
return transformed_result
|
||||
|
||||
def _extract_complete_json_response(self, result: str) -> Optional[dict]:
|
||||
def _extract_complete_json_response(self, result: str) -> dict | None:
|
||||
"""
|
||||
Extract complete json response.
|
||||
"""
|
||||
@ -672,7 +672,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
logger.info("extra error: %s", result)
|
||||
return None
|
||||
|
||||
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
|
||||
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None:
|
||||
"""
|
||||
Extract json from tool call.
|
||||
"""
|
||||
@ -691,7 +691,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
logger.info("extra error: %s", result)
|
||||
return None
|
||||
|
||||
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
|
||||
def _generate_default_result(self, data: ParameterExtractorNodeData):
|
||||
"""
|
||||
Generate default result.
|
||||
"""
|
||||
@ -711,7 +711,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
memory: TokenBufferMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
) -> list[ChatModelMessage]:
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
@ -738,7 +738,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
memory: TokenBufferMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
):
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
@ -753,7 +753,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
if model_mode == ModelMode.CHAT:
|
||||
system_prompt_messages = ChatModelMessage(
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction),
|
||||
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str).replace("{{instructions}}", instruction),
|
||||
)
|
||||
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
|
||||
return [system_prompt_messages, user_prompt_message]
|
||||
@ -774,7 +774,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
context: Optional[str],
|
||||
context: str | None,
|
||||
) -> int:
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
@ -16,8 +14,8 @@ class QuestionClassifierNodeData(BaseNodeData):
|
||||
query_variable_selector: list[str]
|
||||
model: ModelConfig
|
||||
classes: list[ClassConfig]
|
||||
instruction: Optional[str] = None
|
||||
memory: Optional[MemoryConfig] = None
|
||||
instruction: str | None = None
|
||||
memory: MemoryConfig | None = None
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
|
||||
@property
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
@ -59,11 +59,11 @@ class QuestionClassifierNode(BaseNode):
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph: "Graph",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
previous_node_id: Optional[str] = None,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
previous_node_id: str | None = None,
|
||||
thread_pool_id: str | None = None,
|
||||
*,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
) -> None:
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
@ -83,10 +83,10 @@ class QuestionClassifierNode(BaseNode):
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = QuestionClassifierNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -95,7 +95,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -109,7 +109,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
return "1"
|
||||
|
||||
def _run(self):
|
||||
node_data = cast(QuestionClassifierNodeData, self._node_data)
|
||||
node_data = self._node_data
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
# extract variables
|
||||
@ -275,7 +275,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
def get_default_config(cls, filters: dict | None = None):
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters.
|
||||
@ -288,7 +288,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
node_data: QuestionClassifierNodeData,
|
||||
query: str,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
context: Optional[str],
|
||||
context: str | None,
|
||||
) -> int:
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
|
||||
@ -331,7 +331,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
self,
|
||||
node_data: QuestionClassifierNodeData,
|
||||
query: str,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
memory: TokenBufferMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
):
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
@ -15,10 +15,10 @@ class StartNode(BaseNode):
|
||||
|
||||
_node_data: StartNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = StartNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -27,7 +27,7 @@ class StartNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import os
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
@ -18,10 +18,10 @@ class TemplateTransformNode(BaseNode):
|
||||
|
||||
_node_data: TemplateTransformNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = TemplateTransformNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -30,7 +30,7 @@ class TemplateTransformNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -40,7 +40,7 @@ class TemplateTransformNode(BaseNode):
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
def get_default_config(cls, filters: dict | None = None):
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters.
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -45,7 +45,7 @@ class ToolNode(BaseNode):
|
||||
|
||||
_node_data: ToolNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = ToolNodeData.model_validate(data)
|
||||
|
||||
@classmethod
|
||||
@ -57,7 +57,7 @@ class ToolNode(BaseNode):
|
||||
Run the tool node
|
||||
"""
|
||||
|
||||
node_data = cast(ToolNodeData, self._node_data)
|
||||
node_data = self._node_data
|
||||
|
||||
# fetch tool icon
|
||||
tool_info = {
|
||||
@ -439,7 +439,7 @@ class ToolNode(BaseNode):
|
||||
|
||||
return result
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -448,7 +448,7 @@ class ToolNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.variables.types import SegmentType
|
||||
@ -33,4 +31,4 @@ class VariableAssignerNodeData(BaseNodeData):
|
||||
type: str = "variable-assigner"
|
||||
output_type: str
|
||||
variables: list[list[str]]
|
||||
advanced_settings: Optional[AdvancedSettings] = None
|
||||
advanced_settings: AdvancedSettings | None = None
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.variables.segments import Segment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
@ -15,10 +15,10 @@ class VariableAggregatorNode(BaseNode):
|
||||
|
||||
_node_data: VariableAssignerNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = VariableAssignerNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -27,7 +27,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
||||
@ -16,7 +16,7 @@ class UpdatedVariable(BaseModel):
|
||||
name: str
|
||||
selector: Sequence[str]
|
||||
value_type: SegmentType
|
||||
new_value: Any
|
||||
new_value: Any = None
|
||||
|
||||
|
||||
_T = TypeVar("_T", bound=MutableMapping[str, Any])
|
||||
@ -25,7 +25,7 @@ _T = TypeVar("_T", bound=MutableMapping[str, Any])
|
||||
def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable:
|
||||
if len(selector) < SELECTORS_LENGTH:
|
||||
raise Exception("selector too short")
|
||||
node_id, var_name = selector[:2]
|
||||
_, var_name = selector[:2]
|
||||
return UpdatedVariable(
|
||||
name=var_name,
|
||||
selector=list(selector[:2]),
|
||||
|
||||
@ -11,7 +11,7 @@ from .exc import VariableOperatorNodeError
|
||||
class ConversationVariableUpdaterImpl:
|
||||
_engine: Engine | None
|
||||
|
||||
def __init__(self, engine: Engine | None = None) -> None:
|
||||
def __init__(self, engine: Engine | None = None):
|
||||
self._engine = engine
|
||||
|
||||
def _get_engine(self) -> Engine:
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeAlias
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables.segments import BooleanSegment
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
@ -29,10 +30,10 @@ class VariableAssignerNode(BaseNode):
|
||||
|
||||
_node_data: VariableAssignerData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = VariableAssignerData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -41,7 +42,7 @@ class VariableAssignerNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
@ -57,10 +58,10 @@ class VariableAssignerNode(BaseNode):
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph: "Graph",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
previous_node_id: Optional[str] = None,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
previous_node_id: str | None = None,
|
||||
thread_pool_id: str | None = None,
|
||||
conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
|
||||
) -> None:
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
@ -158,8 +159,8 @@ class VariableAssignerNode(BaseNode):
|
||||
def get_zero_value(t: SegmentType):
|
||||
# TODO(QuantumGhost): this should be a method of `SegmentType`.
|
||||
match t:
|
||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
|
||||
return variable_factory.build_segment([])
|
||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER | SegmentType.ARRAY_BOOLEAN:
|
||||
return variable_factory.build_segment_with_type(t, [])
|
||||
case SegmentType.OBJECT:
|
||||
return variable_factory.build_segment({})
|
||||
case SegmentType.STRING:
|
||||
@ -170,5 +171,7 @@ def get_zero_value(t: SegmentType):
|
||||
return variable_factory.build_segment(0.0)
|
||||
case SegmentType.NUMBER:
|
||||
return variable_factory.build_segment(0)
|
||||
case SegmentType.BOOLEAN:
|
||||
return BooleanSegment(value=False)
|
||||
case _:
|
||||
raise VariableOperatorNodeError(f"unsupported variable type: {t}")
|
||||
|
||||
@ -4,9 +4,11 @@ from core.variables import SegmentType
|
||||
EMPTY_VALUE_MAPPING = {
|
||||
SegmentType.STRING: "",
|
||||
SegmentType.NUMBER: 0,
|
||||
SegmentType.BOOLEAN: False,
|
||||
SegmentType.OBJECT: {},
|
||||
SegmentType.ARRAY_ANY: [],
|
||||
SegmentType.ARRAY_STRING: [],
|
||||
SegmentType.ARRAY_NUMBER: [],
|
||||
SegmentType.ARRAY_OBJECT: [],
|
||||
SegmentType.ARRAY_BOOLEAN: [],
|
||||
}
|
||||
|
||||
@ -32,5 +32,5 @@ class ConversationIDNotFoundError(VariableOperatorNodeError):
|
||||
|
||||
|
||||
class InvalidDataError(VariableOperatorNodeError):
|
||||
def __init__(self, message: str) -> None:
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
|
||||
@ -16,28 +16,15 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.INTEGER,
|
||||
SegmentType.FLOAT,
|
||||
SegmentType.BOOLEAN,
|
||||
}
|
||||
case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
|
||||
# Only number variable can be added, subtracted, multiplied or divided
|
||||
return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}
|
||||
case Operation.APPEND | Operation.EXTEND:
|
||||
case Operation.APPEND | Operation.EXTEND | Operation.REMOVE_FIRST | Operation.REMOVE_LAST:
|
||||
# Only array variable can be appended or extended
|
||||
return variable_type in {
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_FILE,
|
||||
}
|
||||
case Operation.REMOVE_FIRST | Operation.REMOVE_LAST:
|
||||
# Only array variable can have elements removed
|
||||
return variable_type in {
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_FILE,
|
||||
}
|
||||
return variable_type.is_array_type()
|
||||
case _:
|
||||
return False
|
||||
|
||||
@ -50,7 +37,7 @@ def is_variable_input_supported(*, operation: Operation):
|
||||
|
||||
def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation):
|
||||
match variable_type:
|
||||
case SegmentType.STRING | SegmentType.OBJECT:
|
||||
case SegmentType.STRING | SegmentType.OBJECT | SegmentType.BOOLEAN:
|
||||
return operation in {Operation.OVER_WRITE, Operation.SET}
|
||||
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
|
||||
return operation in {
|
||||
@ -72,6 +59,9 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
|
||||
case SegmentType.STRING:
|
||||
return isinstance(value, str)
|
||||
|
||||
case SegmentType.BOOLEAN:
|
||||
return isinstance(value, bool)
|
||||
|
||||
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
|
||||
if not isinstance(value, int | float):
|
||||
return False
|
||||
@ -91,6 +81,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
|
||||
return isinstance(value, int | float)
|
||||
case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND:
|
||||
return isinstance(value, dict)
|
||||
case SegmentType.ARRAY_BOOLEAN if operation == Operation.APPEND:
|
||||
return isinstance(value, bool)
|
||||
|
||||
# Array & Extend / Overwrite
|
||||
case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
@ -101,6 +93,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
|
||||
return isinstance(value, list) and all(isinstance(item, int | float) for item in value)
|
||||
case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
return isinstance(value, list) and all(isinstance(item, dict) for item in value)
|
||||
case SegmentType.ARRAY_BOOLEAN if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
return isinstance(value, list) and all(isinstance(item, bool) for item in value)
|
||||
|
||||
case _:
|
||||
return False
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from collections.abc import Mapping, MutableMapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables import SegmentType, Variable
|
||||
@ -58,10 +58,10 @@ class VariableAssignerNode(BaseNode):
|
||||
|
||||
_node_data: VariableAssignerNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = VariableAssignerNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
@ -70,7 +70,7 @@ class VariableAssignerNode(BaseNode):
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
||||
Reference in New Issue
Block a user