mirror of
https://github.com/langgenius/dify.git
synced 2026-04-27 22:18:15 +08:00
Merge remote-tracking branch 'upstream/main' into feat/human-input-merge-again
This commit is contained in:
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
@ -167,7 +169,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
variable_pool: VariablePool,
|
||||
node_data: AgentNodeData,
|
||||
for_log: bool = False,
|
||||
strategy: "PluginAgentStrategy",
|
||||
strategy: PluginAgentStrategy,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
@ -233,7 +235,18 @@ class AgentNode(Node[AgentNodeData]):
|
||||
0,
|
||||
):
|
||||
value_param = param.get("value", {})
|
||||
params[key] = value_param.get("value", "") if value_param is not None else None
|
||||
if value_param and value_param.get("type", "") == "variable":
|
||||
variable_selector = value_param.get("value")
|
||||
if not variable_selector:
|
||||
raise ValueError("Variable selector is missing for a variable-type parameter.")
|
||||
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(variable_selector))
|
||||
|
||||
params[key] = variable.value
|
||||
else:
|
||||
params[key] = value_param.get("value", "") if value_param is not None else None
|
||||
else:
|
||||
params[key] = None
|
||||
parameters = params
|
||||
@ -328,7 +341,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
def _generate_credentials(
|
||||
self,
|
||||
parameters: dict[str, Any],
|
||||
) -> "InvokeCredentials":
|
||||
) -> InvokeCredentials:
|
||||
"""
|
||||
Generate credentials based on the given agent parameters.
|
||||
"""
|
||||
@ -442,9 +455,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
model_schema.features.remove(feature)
|
||||
return model_schema
|
||||
|
||||
def _filter_mcp_type_tool(
|
||||
self, strategy: "PluginAgentStrategy", tools: list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Filter MCP type tool
|
||||
:param strategy: plugin agent strategy
|
||||
@ -483,7 +494,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json_list: list[dict] = []
|
||||
json_list: list[dict | list] = []
|
||||
|
||||
agent_logs: list[AgentLogEvent] = []
|
||||
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
|
||||
@ -557,13 +568,18 @@ class AgentNode(Node[AgentNodeData]):
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
if node_type == NodeType.AGENT:
|
||||
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
|
||||
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
|
||||
agent_execution_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey(key): value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
if isinstance(message.message.json_object, dict):
|
||||
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
|
||||
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
|
||||
agent_execution_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey(key): value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
else:
|
||||
msg_metadata = {}
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
agent_execution_metadata = {}
|
||||
if message.message.json_object:
|
||||
json_list.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
@ -672,7 +688,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
yield agent_log
|
||||
|
||||
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
|
||||
json_output: list[dict[str, Any]] = []
|
||||
json_output: list[dict[str, Any] | list[Any]] = []
|
||||
|
||||
# Step 1: append each agent log as its own dict.
|
||||
if agent_logs:
|
||||
|
||||
@ -119,3 +119,14 @@ class AgentVariableTypeError(AgentNodeError):
|
||||
self.expected_type = expected_type
|
||||
self.actual_type = actual_type
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentMaxIterationError(AgentNodeError):
|
||||
"""Exception raised when the agent exceeds the maximum iteration limit."""
|
||||
|
||||
def __init__(self, max_iteration: int):
|
||||
self.max_iteration = max_iteration
|
||||
super().__init__(
|
||||
f"Agent exceeded the maximum iteration limit of {max_iteration}. "
|
||||
f"The agent was unable to complete the task within the allowed number of iterations."
|
||||
)
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC
|
||||
from builtins import type as type_
|
||||
@ -111,7 +113,7 @@ class DefaultValue(BaseModel):
|
||||
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_value_type(self) -> "DefaultValue":
|
||||
def validate_value_type(self) -> DefaultValue:
|
||||
# Type validation configuration
|
||||
type_validators = {
|
||||
DefaultValueType.STRING: {
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import operator
|
||||
@ -72,7 +74,7 @@ class Node(Generic[NodeDataT]):
|
||||
in its output.
|
||||
"""
|
||||
|
||||
node_type: ClassVar["NodeType"]
|
||||
node_type: ClassVar[NodeType]
|
||||
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
|
||||
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
|
||||
|
||||
@ -211,14 +213,14 @@ class Node(Generic[NodeDataT]):
|
||||
return None
|
||||
|
||||
# Global registry populated via __init_subclass__
|
||||
_registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}
|
||||
_registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> None:
|
||||
self._graph_init_params = graph_init_params
|
||||
self.id = id
|
||||
@ -254,7 +256,7 @@ class Node(Generic[NodeDataT]):
|
||||
return
|
||||
|
||||
@property
|
||||
def graph_init_params(self) -> "GraphInitParams":
|
||||
def graph_init_params(self) -> GraphInitParams:
|
||||
return self._graph_init_params
|
||||
|
||||
@property
|
||||
@ -300,6 +302,10 @@ class Node(Generic[NodeDataT]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _should_stop(self) -> bool:
|
||||
"""Check if execution should be stopped."""
|
||||
return self.graph_runtime_state.stop_event.is_set()
|
||||
|
||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
execution_id = self.ensure_execution_id()
|
||||
self._start_at = naive_utc_now()
|
||||
@ -368,6 +374,21 @@ class Node(Generic[NodeDataT]):
|
||||
yield event
|
||||
else:
|
||||
yield event
|
||||
|
||||
if self._should_stop():
|
||||
error_message = "Execution cancelled"
|
||||
yield NodeRunFailedEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error_message,
|
||||
),
|
||||
error=error_message,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception("Node %s failed to run", self._node_id)
|
||||
result = NodeRunResult(
|
||||
@ -474,7 +495,7 @@ class Node(Generic[NodeDataT]):
|
||||
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
|
||||
|
||||
@classmethod
|
||||
def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]:
|
||||
def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]:
|
||||
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
|
||||
|
||||
Import all modules under core.workflow.nodes so subclasses register themselves on import.
|
||||
@ -484,12 +505,8 @@ class Node(Generic[NodeDataT]):
|
||||
import core.workflow.nodes as _nodes_pkg
|
||||
|
||||
for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."):
|
||||
# Avoid importing modules that depend on the registry to prevent circular imports
|
||||
# e.g. node_factory imports node_mapping which builds the mapping here.
|
||||
if _modname in {
|
||||
"core.workflow.nodes.node_factory",
|
||||
"core.workflow.nodes.node_mapping",
|
||||
}:
|
||||
# Avoid importing modules that depend on the registry to prevent circular imports.
|
||||
if _modname == "core.workflow.nodes.node_mapping":
|
||||
continue
|
||||
importlib.import_module(_modname)
|
||||
|
||||
|
||||
@ -4,6 +4,8 @@ This module provides a unified template structure for both Answer and End nodes,
|
||||
similar to SegmentGroup but focused on template representation without values.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
@ -58,7 +60,7 @@ class Template:
|
||||
segments: list[TemplateSegmentUnion]
|
||||
|
||||
@classmethod
|
||||
def from_answer_template(cls, template_str: str) -> "Template":
|
||||
def from_answer_template(cls, template_str: str) -> Template:
|
||||
"""Create a Template from an Answer node template string.
|
||||
|
||||
Example:
|
||||
@ -107,7 +109,7 @@ class Template:
|
||||
return cls(segments=segments)
|
||||
|
||||
@classmethod
|
||||
def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> "Template":
|
||||
def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> Template:
|
||||
"""Create a Template from an End node outputs configuration.
|
||||
|
||||
End nodes are treated as templates of concatenated variables with newlines.
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from decimal import Decimal
|
||||
from typing import Any, cast
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
@ -13,6 +12,7 @@ from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.entities import CodeNodeData
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
|
||||
from .exc import (
|
||||
CodeNodeError,
|
||||
@ -20,9 +20,41 @@ from .exc import (
|
||||
OutputValidationError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class CodeNode(Node[CodeNodeData]):
|
||||
node_type = NodeType.CODE
|
||||
_DEFAULT_CODE_PROVIDERS: ClassVar[tuple[type[CodeNodeProvider], ...]] = (
|
||||
Python3CodeProvider,
|
||||
JavascriptCodeProvider,
|
||||
)
|
||||
_limits: CodeNodeLimits
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
code_executor: type[CodeExecutor] | None = None,
|
||||
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
|
||||
code_limits: CodeNodeLimits,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
|
||||
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
|
||||
tuple(code_providers) if code_providers else self._DEFAULT_CODE_PROVIDERS
|
||||
)
|
||||
self._limits = code_limits
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
@ -35,11 +67,16 @@ class CodeNode(Node[CodeNodeData]):
|
||||
if filters:
|
||||
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
|
||||
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
|
||||
code_provider: type[CodeNodeProvider] = next(
|
||||
provider for provider in cls._DEFAULT_CODE_PROVIDERS if provider.is_accept_language(code_language)
|
||||
)
|
||||
|
||||
return code_provider.get_default_config()
|
||||
|
||||
@classmethod
|
||||
def default_code_providers(cls) -> tuple[type[CodeNodeProvider], ...]:
|
||||
return cls._DEFAULT_CODE_PROVIDERS
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
@ -60,7 +97,8 @@ class CodeNode(Node[CodeNodeData]):
|
||||
variables[variable_name] = variable.to_object() if variable else None
|
||||
# Run code
|
||||
try:
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
_ = self._select_code_provider(code_language)
|
||||
result = self._code_executor.execute_workflow_code_template(
|
||||
language=code_language,
|
||||
code=code,
|
||||
inputs=variables,
|
||||
@ -75,6 +113,12 @@ class CodeNode(Node[CodeNodeData]):
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
|
||||
|
||||
def _select_code_provider(self, code_language: CodeLanguage) -> type[CodeNodeProvider]:
|
||||
for provider in self._code_providers:
|
||||
if provider.is_accept_language(code_language):
|
||||
return provider
|
||||
raise CodeNodeError(f"Unsupported code language: {code_language}")
|
||||
|
||||
def _check_string(self, value: str | None, variable: str) -> str | None:
|
||||
"""
|
||||
Check string
|
||||
@ -85,10 +129,10 @@ class CodeNode(Node[CodeNodeData]):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
|
||||
if len(value) > self._limits.max_string_length:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{variable}` must be"
|
||||
f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters"
|
||||
f" less than {self._limits.max_string_length} characters"
|
||||
)
|
||||
|
||||
return value.replace("\x00", "")
|
||||
@ -109,20 +153,20 @@ class CodeNode(Node[CodeNodeData]):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
|
||||
if value > self._limits.max_number or value < self._limits.min_number:
|
||||
raise OutputValidationError(
|
||||
f"Output variable `{variable}` is out of range,"
|
||||
f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}."
|
||||
f" it must be between {self._limits.min_number} and {self._limits.max_number}."
|
||||
)
|
||||
|
||||
if isinstance(value, float):
|
||||
decimal_value = Decimal(str(value)).normalize()
|
||||
precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator]
|
||||
# raise error if precision is too high
|
||||
if precision > dify_config.CODE_MAX_PRECISION:
|
||||
if precision > self._limits.max_precision:
|
||||
raise OutputValidationError(
|
||||
f"Output variable `{variable}` has too high precision,"
|
||||
f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
|
||||
f" it must be less than {self._limits.max_precision} digits."
|
||||
)
|
||||
|
||||
return value
|
||||
@ -137,8 +181,8 @@ class CodeNode(Node[CodeNodeData]):
|
||||
# TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes.
|
||||
# Note that `_transform_result` may produce lists containing `None` values,
|
||||
# which don't conform to the type requirements of `Array*Segment` classes.
|
||||
if depth > dify_config.CODE_MAX_DEPTH:
|
||||
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
|
||||
if depth > self._limits.max_depth:
|
||||
raise DepthLimitError(f"Depth limit {self._limits.max_depth} reached, object too deep.")
|
||||
|
||||
transformed_result: dict[str, Any] = {}
|
||||
if output_schema is None:
|
||||
@ -272,10 +316,10 @@ class CodeNode(Node[CodeNodeData]):
|
||||
f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead."
|
||||
)
|
||||
else:
|
||||
if len(value) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
|
||||
if len(value) > self._limits.max_number_array_length:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
|
||||
f" less than {self._limits.max_number_array_length} elements."
|
||||
)
|
||||
|
||||
for i, inner_value in enumerate(value):
|
||||
@ -305,10 +349,10 @@ class CodeNode(Node[CodeNodeData]):
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
|
||||
if len(result[output_name]) > self._limits.max_string_array_length:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements."
|
||||
f" less than {self._limits.max_string_array_length} elements."
|
||||
)
|
||||
|
||||
transformed_result[output_name] = [
|
||||
@ -326,10 +370,10 @@ class CodeNode(Node[CodeNodeData]):
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
|
||||
if len(result[output_name]) > self._limits.max_object_array_length:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements."
|
||||
f" less than {self._limits.max_object_array_length} elements."
|
||||
)
|
||||
|
||||
for i, value in enumerate(result[output_name]):
|
||||
|
||||
13
api/core/workflow/nodes/code/limits.py
Normal file
13
api/core/workflow/nodes/code/limits.py
Normal file
@ -0,0 +1,13 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CodeNodeLimits:
|
||||
max_string_length: int
|
||||
max_number: int | float
|
||||
min_number: int | float
|
||||
max_precision: int
|
||||
max_depth: int
|
||||
max_number_array_length: int
|
||||
max_string_array_length: int
|
||||
max_object_array_length: int
|
||||
@ -301,7 +301,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json: list[dict] = []
|
||||
json: list[dict | list] = []
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
|
||||
@ -17,6 +17,7 @@ from core.helper import ssrf_proxy
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
from ..protocols import FileManagerProtocol, HttpClientProtocol
|
||||
from .entities import (
|
||||
HttpRequestNodeAuthorization,
|
||||
HttpRequestNodeData,
|
||||
@ -78,6 +79,8 @@ class Executor:
|
||||
timeout: HttpRequestNodeTimeout,
|
||||
variable_pool: VariablePool,
|
||||
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||
http_client: HttpClientProtocol = ssrf_proxy,
|
||||
file_manager: FileManagerProtocol = file_manager,
|
||||
):
|
||||
# If authorization API key is present, convert the API key using the variable pool
|
||||
if node_data.authorization.type == "api-key":
|
||||
@ -104,6 +107,8 @@ class Executor:
|
||||
self.data = None
|
||||
self.json = None
|
||||
self.max_retries = max_retries
|
||||
self._http_client = http_client
|
||||
self._file_manager = file_manager
|
||||
|
||||
# init template
|
||||
self.variable_pool = variable_pool
|
||||
@ -200,7 +205,7 @@ class Executor:
|
||||
if file_variable is None:
|
||||
raise FileFetchError(f"cannot fetch file with selector {file_selector}")
|
||||
file = file_variable.value
|
||||
self.content = file_manager.download(file)
|
||||
self.content = self._file_manager.download(file)
|
||||
case "x-www-form-urlencoded":
|
||||
form_data = {
|
||||
self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template(
|
||||
@ -239,7 +244,7 @@ class Executor:
|
||||
):
|
||||
file_tuple = (
|
||||
file.filename,
|
||||
file_manager.download(file),
|
||||
self._file_manager.download(file),
|
||||
file.mime_type or "application/octet-stream",
|
||||
)
|
||||
if key not in files:
|
||||
@ -332,19 +337,18 @@ class Executor:
|
||||
do http request depending on api bundle
|
||||
"""
|
||||
_METHOD_MAP = {
|
||||
"get": ssrf_proxy.get,
|
||||
"head": ssrf_proxy.head,
|
||||
"post": ssrf_proxy.post,
|
||||
"put": ssrf_proxy.put,
|
||||
"delete": ssrf_proxy.delete,
|
||||
"patch": ssrf_proxy.patch,
|
||||
"get": self._http_client.get,
|
||||
"head": self._http_client.head,
|
||||
"post": self._http_client.post,
|
||||
"put": self._http_client.put,
|
||||
"delete": self._http_client.delete,
|
||||
"patch": self._http_client.patch,
|
||||
}
|
||||
method_lc = self.method.lower()
|
||||
if method_lc not in _METHOD_MAP:
|
||||
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
|
||||
|
||||
request_args = {
|
||||
"url": self.url,
|
||||
"data": self.data,
|
||||
"files": self.files,
|
||||
"json": self.json,
|
||||
@ -357,8 +361,12 @@ class Executor:
|
||||
}
|
||||
# request_args = {k: v for k, v in request_args.items() if v is not None}
|
||||
try:
|
||||
response: httpx.Response = _METHOD_MAP[method_lc](**request_args, max_retries=self.max_retries)
|
||||
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
|
||||
response: httpx.Response = _METHOD_MAP[method_lc](
|
||||
url=self.url,
|
||||
**request_args,
|
||||
max_retries=self.max_retries,
|
||||
)
|
||||
except (self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e:
|
||||
raise HttpRequestNodeError(str(e)) from e
|
||||
# FIXME: fix type ignore, this maybe httpx type issue
|
||||
return response
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
import logging
|
||||
import mimetypes
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.file import File, FileTransferMethod, file_manager
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
@ -13,6 +14,7 @@ from core.workflow.nodes.base import variable_template_parser
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.http_request.executor import Executor
|
||||
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
|
||||
from factories import file_factory
|
||||
|
||||
from .entities import (
|
||||
@ -30,10 +32,35 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
node_type = NodeType.HTTP_REQUEST
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
http_client: HttpClientProtocol = ssrf_proxy,
|
||||
tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
|
||||
file_manager: FileManagerProtocol = file_manager,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._http_client = http_client
|
||||
self._tool_file_manager_factory = tool_file_manager_factory
|
||||
self._file_manager = file_manager
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
@ -71,6 +98,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
timeout=self._get_request_timeout(self.node_data),
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
max_retries=0,
|
||||
http_client=self._http_client,
|
||||
file_manager=self._file_manager,
|
||||
)
|
||||
process_data["request"] = http_executor.to_log()
|
||||
|
||||
@ -199,7 +228,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
mime_type = (
|
||||
content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
|
||||
)
|
||||
tool_file_manager = ToolFileManager()
|
||||
tool_file_manager = self._tool_file_manager_factory()
|
||||
|
||||
tool_file = tool_file_manager.create_file_by_raw(
|
||||
user_id=self.user_id,
|
||||
|
||||
@ -1,17 +1,15 @@
|
||||
import contextvars
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, NewType, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.variables import IntegerVariable, NoneSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.variables.variables import Variable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import (
|
||||
NodeExecutionType,
|
||||
@ -39,7 +37,6 @@ from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||
from core.workflow.runtime import VariablePool
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
from .exc import (
|
||||
InvalidIteratorValueError,
|
||||
@ -51,6 +48,7 @@ from .exc import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.context import IExecutionContext
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -240,7 +238,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
datetime,
|
||||
list[GraphNodeEventBase],
|
||||
object | None,
|
||||
dict[str, VariableUnion],
|
||||
dict[str, Variable],
|
||||
LLMUsage,
|
||||
]
|
||||
],
|
||||
@ -252,8 +250,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
self._execute_single_iteration_parallel,
|
||||
index=index,
|
||||
item=item,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
context_vars=contextvars.copy_context(),
|
||||
execution_context=self._capture_execution_context(),
|
||||
)
|
||||
future_to_index[future] = index
|
||||
|
||||
@ -306,11 +303,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
self,
|
||||
index: int,
|
||||
item: object,
|
||||
flask_app: Flask,
|
||||
context_vars: contextvars.Context,
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]:
|
||||
execution_context: "IExecutionContext",
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
|
||||
"""Execute a single iteration in parallel mode and return results."""
|
||||
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
|
||||
with execution_context:
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
events: list[GraphNodeEventBase] = []
|
||||
outputs_temp: list[object] = []
|
||||
@ -339,6 +335,12 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
graph_engine.graph_runtime_state.llm_usage,
|
||||
)
|
||||
|
||||
def _capture_execution_context(self) -> "IExecutionContext":
|
||||
"""Capture current execution context for parallel iterations."""
|
||||
from core.workflow.context import capture_current_context
|
||||
|
||||
return capture_current_context()
|
||||
|
||||
def _handle_iteration_success(
|
||||
self,
|
||||
started_at: datetime,
|
||||
@ -515,11 +517,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
|
||||
return variable_mapping
|
||||
|
||||
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]:
|
||||
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]:
|
||||
conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
|
||||
return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()}
|
||||
|
||||
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None:
|
||||
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None:
|
||||
parent_pool = self.graph_runtime_state.variable_pool
|
||||
parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
|
||||
|
||||
@ -586,11 +588,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
|
||||
def _create_graph_engine(self, index: int, item: object):
|
||||
# Import dependencies
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
# Create GraphInitParams from node attributes
|
||||
|
||||
@ -6,7 +6,7 @@ from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from sqlalchemy import and_, func, literal, or_, select
|
||||
from sqlalchemy import and_, func, or_, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
@ -460,7 +460,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
if automatic_metadata_filters:
|
||||
conditions = []
|
||||
for sequence, filter in enumerate(automatic_metadata_filters):
|
||||
self._process_metadata_filter_func(
|
||||
DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
filter.get("condition", ""),
|
||||
filter.get("metadata_name", ""),
|
||||
@ -504,7 +504,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
value=expected_value,
|
||||
)
|
||||
)
|
||||
filters = self._process_metadata_filter_func(
|
||||
filters = DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
@ -603,87 +603,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
return [], usage
|
||||
return automatic_metadata_filters, usage
|
||||
|
||||
def _process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
|
||||
) -> list[Any]:
|
||||
if value is None and condition not in ("empty", "not empty"):
|
||||
return filters
|
||||
|
||||
json_field = Document.doc_metadata[metadata_name].as_string()
|
||||
|
||||
match condition:
|
||||
case "contains":
|
||||
filters.append(json_field.like(f"%{value}%"))
|
||||
|
||||
case "not contains":
|
||||
filters.append(json_field.notlike(f"%{value}%"))
|
||||
|
||||
case "start with":
|
||||
filters.append(json_field.like(f"{value}%"))
|
||||
|
||||
case "end with":
|
||||
filters.append(json_field.like(f"%{value}"))
|
||||
case "in":
|
||||
if isinstance(value, str):
|
||||
value_list = [v.strip() for v in value.split(",") if v.strip()]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value_list = [str(v) for v in value if v is not None]
|
||||
else:
|
||||
value_list = [str(value)] if value is not None else []
|
||||
|
||||
if not value_list:
|
||||
filters.append(literal(False))
|
||||
else:
|
||||
filters.append(json_field.in_(value_list))
|
||||
|
||||
case "not in":
|
||||
if isinstance(value, str):
|
||||
value_list = [v.strip() for v in value.split(",") if v.strip()]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value_list = [str(v) for v in value if v is not None]
|
||||
else:
|
||||
value_list = [str(value)] if value is not None else []
|
||||
|
||||
if not value_list:
|
||||
filters.append(literal(True))
|
||||
else:
|
||||
filters.append(json_field.notin_(value_list))
|
||||
|
||||
case "is" | "=":
|
||||
if isinstance(value, str):
|
||||
filters.append(json_field == value)
|
||||
elif isinstance(value, (int, float)):
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() == value)
|
||||
|
||||
case "is not" | "≠":
|
||||
if isinstance(value, str):
|
||||
filters.append(json_field != value)
|
||||
elif isinstance(value, (int, float)):
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() != value)
|
||||
|
||||
case "empty":
|
||||
filters.append(Document.doc_metadata[metadata_name].is_(None))
|
||||
|
||||
case "not empty":
|
||||
filters.append(Document.doc_metadata[metadata_name].isnot(None))
|
||||
|
||||
case "before" | "<":
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() < value)
|
||||
|
||||
case "after" | ">":
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() > value)
|
||||
|
||||
case "≤" | "<=":
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() <= value)
|
||||
|
||||
case "≥" | ">=":
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() >= value)
|
||||
|
||||
case _:
|
||||
pass
|
||||
|
||||
return filters
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
|
||||
@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||
from core.file.models import File
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
@ -136,21 +136,37 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
|
||||
used_quota = 1
|
||||
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
with Session(db.engine) as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="paid",
|
||||
)
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
@ -113,7 +115,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
|
||||
# Instance attributes specific to LLMNode.
|
||||
# Output variable for file
|
||||
_file_outputs: list["File"]
|
||||
_file_outputs: list[File]
|
||||
|
||||
_llm_file_saver: LLMFileSaver
|
||||
|
||||
@ -121,8 +123,8 @@ class LLMNode(Node[LLMNodeData]):
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
*,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
@ -361,7 +363,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
structured_output_enabled: bool,
|
||||
structured_output: Mapping[str, Any] | None = None,
|
||||
file_saver: LLMFileSaver,
|
||||
file_outputs: list["File"],
|
||||
file_outputs: list[File],
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
||||
@ -415,7 +417,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
*,
|
||||
invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
|
||||
file_saver: LLMFileSaver,
|
||||
file_outputs: list["File"],
|
||||
file_outputs: list[File],
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
||||
@ -525,7 +527,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _image_file_to_markdown(file: "File", /):
|
||||
def _image_file_to_markdown(file: File, /):
|
||||
text_chunk = f"})"
|
||||
return text_chunk
|
||||
|
||||
@ -774,7 +776,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
def fetch_prompt_messages(
|
||||
*,
|
||||
sys_query: str | None = None,
|
||||
sys_files: Sequence["File"],
|
||||
sys_files: Sequence[File],
|
||||
context: str | None = None,
|
||||
memory: TokenBufferMemory | None = None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
@ -785,7 +787,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
variable_pool: VariablePool,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
tenant_id: str,
|
||||
context_files: list["File"] | None = None,
|
||||
context_files: list[File] | None = None,
|
||||
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
|
||||
@ -1137,7 +1139,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
*,
|
||||
invoke_result: LLMResult | LLMResultWithStructuredOutput,
|
||||
saver: LLMFileSaver,
|
||||
file_outputs: list["File"],
|
||||
file_outputs: list[File],
|
||||
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
||||
request_latency: float | None = None,
|
||||
) -> ModelInvokeCompletedEvent:
|
||||
@ -1179,7 +1181,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
*,
|
||||
content: ImagePromptMessageContent,
|
||||
file_saver: LLMFileSaver,
|
||||
) -> "File":
|
||||
) -> File:
|
||||
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins.
|
||||
|
||||
There are two kinds of multimodal outputs:
|
||||
@ -1229,7 +1231,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
*,
|
||||
contents: str | list[PromptMessageContentUnionTypes] | None,
|
||||
file_saver: LLMFileSaver,
|
||||
file_outputs: list["File"],
|
||||
file_outputs: list[File],
|
||||
) -> Generator[str, None, None]:
|
||||
"""Convert intermediate prompt messages into strings and yield them to the caller.
|
||||
|
||||
|
||||
@ -413,11 +413,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
|
||||
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
|
||||
# Import dependencies
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
# Create GraphInitParams from node attributes
|
||||
|
||||
@ -1,80 +0,0 @@
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph import NodeFactory
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from libs.typing import is_str, is_str_dict
|
||||
|
||||
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
@final
|
||||
class DifyNodeFactory(NodeFactory):
|
||||
"""
|
||||
Default implementation of NodeFactory that uses the traditional node mapping.
|
||||
|
||||
This factory creates nodes by looking up their types in NODE_TYPE_CLASSES_MAPPING
|
||||
and instantiating the appropriate node class.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
) -> None:
|
||||
self.graph_init_params = graph_init_params
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
|
||||
@override
|
||||
def create_node(self, node_config: dict[str, object]) -> Node:
|
||||
"""
|
||||
Create a Node instance from node configuration data using the traditional mapping.
|
||||
|
||||
:param node_config: node configuration dictionary containing type and other data
|
||||
:return: initialized Node instance
|
||||
:raises ValueError: if node type is unknown or configuration is invalid
|
||||
"""
|
||||
# Get node_id from config
|
||||
node_id = node_config.get("id")
|
||||
if not is_str(node_id):
|
||||
raise ValueError("Node config missing id")
|
||||
|
||||
# Get node type from config
|
||||
node_data = node_config.get("data", {})
|
||||
if not is_str_dict(node_data):
|
||||
raise ValueError(f"Node {node_id} missing data information")
|
||||
|
||||
node_type_str = node_data.get("type")
|
||||
if not is_str(node_type_str):
|
||||
raise ValueError(f"Node {node_id} missing or invalid type information")
|
||||
|
||||
try:
|
||||
node_type = NodeType(node_type_str)
|
||||
except ValueError:
|
||||
raise ValueError(f"Unknown node type: {node_type_str}")
|
||||
|
||||
# Get node class
|
||||
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
|
||||
if not node_mapping:
|
||||
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||
|
||||
latest_node_class = node_mapping.get(LATEST_VERSION)
|
||||
node_version = str(node_data.get("version", "1"))
|
||||
matched_node_class = node_mapping.get(node_version)
|
||||
node_class = matched_node_class or latest_node_class
|
||||
if not node_class:
|
||||
raise ValueError(f"No latest version class found for node type: {node_type}")
|
||||
|
||||
# Create node instance
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
)
|
||||
29
api/core/workflow/nodes/protocols.py
Normal file
29
api/core/workflow/nodes/protocols.py
Normal file
@ -0,0 +1,29 @@
|
||||
from typing import Protocol
|
||||
|
||||
import httpx
|
||||
|
||||
from core.file import File
|
||||
|
||||
|
||||
class HttpClientProtocol(Protocol):
|
||||
@property
|
||||
def max_retries_exceeded_error(self) -> type[Exception]: ...
|
||||
|
||||
@property
|
||||
def request_error(self) -> type[Exception]: ...
|
||||
|
||||
def get(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
|
||||
def head(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
|
||||
def post(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
|
||||
def put(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
|
||||
def delete(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
|
||||
def patch(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
|
||||
|
||||
class FileManagerProtocol(Protocol):
|
||||
def download(self, f: File, /) -> bytes: ...
|
||||
@ -1,4 +1,3 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from jsonschema import Draft7Validator, ValidationError
|
||||
@ -43,25 +42,22 @@ class StartNode(Node[StartNodeData]):
|
||||
if value is None and variable.required:
|
||||
raise ValueError(f"{key} is required in input form")
|
||||
|
||||
# If no value provided, skip further processing for this key
|
||||
if not value:
|
||||
continue
|
||||
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError(f"JSON object for '{key}' must be an object")
|
||||
|
||||
# Overwrite with normalized dict to ensure downstream consistency
|
||||
node_inputs[key] = value
|
||||
|
||||
# If schema exists, then validate against it
|
||||
schema = variable.json_schema
|
||||
if not schema:
|
||||
continue
|
||||
|
||||
if not value:
|
||||
continue
|
||||
|
||||
try:
|
||||
json_schema = json.loads(schema)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"{schema} must be a valid JSON object")
|
||||
|
||||
try:
|
||||
json_value = json.loads(value)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"{value} must be a valid JSON object")
|
||||
|
||||
try:
|
||||
Draft7Validator(json_schema).validate(json_value)
|
||||
Draft7Validator(schema).validate(value)
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
|
||||
node_inputs[key] = json_value
|
||||
|
||||
@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
|
||||
|
||||
class TemplateRenderError(ValueError):
|
||||
"""Raised when rendering a Jinja2 template fails."""
|
||||
|
||||
|
||||
class Jinja2TemplateRenderer(Protocol):
|
||||
"""Render Jinja2 templates for template transform nodes."""
|
||||
|
||||
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
|
||||
"""Render a Jinja2 template with provided variables."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
|
||||
"""Adapter that renders Jinja2 templates via CodeExecutor."""
|
||||
|
||||
_code_executor: type[CodeExecutor]
|
||||
|
||||
def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None:
|
||||
self._code_executor = code_executor or CodeExecutor
|
||||
|
||||
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
|
||||
try:
|
||||
result = self._code_executor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JINJA2, code=template, inputs=variables
|
||||
)
|
||||
except CodeExecutionError as exc:
|
||||
raise TemplateRenderError(str(exc)) from exc
|
||||
|
||||
rendered = result.get("result")
|
||||
if not isinstance(rendered, str):
|
||||
raise TemplateRenderError("Template render result must be a string.")
|
||||
return rendered
|
||||
@ -1,18 +1,44 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
|
||||
from core.workflow.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
Jinja2TemplateRenderer,
|
||||
TemplateRenderError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
|
||||
|
||||
class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
node_type = NodeType.TEMPLATE_TRANSFORM
|
||||
_template_renderer: Jinja2TemplateRenderer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
template_renderer: Jinja2TemplateRenderer | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
@ -39,13 +65,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
variables[variable_name] = value.to_object() if value else None
|
||||
# Run code
|
||||
try:
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
|
||||
)
|
||||
except CodeExecutionError as e:
|
||||
rendered = self._template_renderer.render_template(self.node_data.template, variables)
|
||||
except TemplateRenderError as e:
|
||||
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
|
||||
|
||||
if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
|
||||
if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
|
||||
return NodeRunResult(
|
||||
inputs=variables,
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
@ -53,7 +77,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]}
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -244,7 +244,7 @@ class ToolNode(Node[ToolNodeData]):
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json: list[dict] = []
|
||||
json: list[dict | list] = []
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
@ -400,7 +400,7 @@ class ToolNode(Node[ToolNodeData]):
|
||||
message.message.metadata = dict_metadata
|
||||
|
||||
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
|
||||
json_output: list[dict[str, Any]] = []
|
||||
json_output: list[dict[str, Any] | list[Any]] = []
|
||||
|
||||
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
|
||||
if json:
|
||||
|
||||
@ -1,28 +0,0 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.variables.variables import Variable
|
||||
from extensions.ext_database import db
|
||||
from models import ConversationVariable
|
||||
|
||||
from .exc import VariableOperatorNodeError
|
||||
|
||||
|
||||
class ConversationVariableUpdaterImpl:
|
||||
def update(self, conversation_id: str, variable: Variable):
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
row = session.scalar(stmt)
|
||||
if not row:
|
||||
raise VariableOperatorNodeError("conversation variable not found in the database")
|
||||
row.data = variable.model_dump_json()
|
||||
session.commit()
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
|
||||
def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
|
||||
return ConversationVariableUpdaterImpl()
|
||||
@ -1,9 +1,8 @@
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables import SegmentType, VariableBase
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
@ -11,19 +10,14 @@ from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
|
||||
from ..common.impl import conversation_variable_updater_factory
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
|
||||
|
||||
|
||||
class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
node_type = NodeType.VARIABLE_ASSIGNER
|
||||
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -31,7 +25,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
@ -39,7 +32,15 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._conv_var_updater_factory = conv_var_updater_factory
|
||||
|
||||
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
|
||||
"""
|
||||
Check if this Variable Assigner node blocks the output of specific variables.
|
||||
|
||||
Returns True if this node updates any of the requested conversation variables.
|
||||
"""
|
||||
assigned_selector = tuple(self.node_data.assigned_variable_selector)
|
||||
return assigned_selector in variable_selectors
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
@ -72,7 +73,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
assigned_variable_selector = self.node_data.assigned_variable_selector
|
||||
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
||||
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
|
||||
if not isinstance(original_variable, Variable):
|
||||
if not isinstance(original_variable, VariableBase):
|
||||
raise VariableOperatorNodeError("assigned variable not found")
|
||||
|
||||
match self.node_data.write_mode:
|
||||
@ -96,16 +97,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
# Over write the variable.
|
||||
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
|
||||
|
||||
# TODO: Move database operation to the pipeline.
|
||||
# Update conversation variable.
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
|
||||
if not conversation_id:
|
||||
raise VariableOperatorNodeError("conversation_id not found")
|
||||
conv_var_updater = self._conv_var_updater_factory()
|
||||
conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable)
|
||||
conv_var_updater.flush()
|
||||
updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)]
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={
|
||||
|
||||
@ -1,24 +1,20 @@
|
||||
import json
|
||||
from collections.abc import Mapping, MutableMapping, Sequence
|
||||
from typing import Any, cast
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables import SegmentType, VariableBase
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
|
||||
|
||||
from . import helpers
|
||||
from .entities import VariableAssignerNodeData, VariableOperationItem
|
||||
from .enums import InputType, Operation
|
||||
from .exc import (
|
||||
ConversationIDNotFoundError,
|
||||
InputTypeNotSupportedError,
|
||||
InvalidDataError,
|
||||
InvalidInputValueError,
|
||||
@ -26,6 +22,10 @@ from .exc import (
|
||||
VariableNotFoundError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
|
||||
selector_node_id = item.variable_selector[0]
|
||||
@ -53,6 +53,20 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
|
||||
class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||
node_type = NodeType.VARIABLE_ASSIGNER
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
|
||||
"""
|
||||
Check if this Variable Assigner node blocks the output of specific variables.
|
||||
@ -70,9 +84,6 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||
|
||||
return False
|
||||
|
||||
def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
|
||||
return conversation_variable_updater_factory()
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "2"
|
||||
@ -107,7 +118,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||
# ==================== Validation Part
|
||||
|
||||
# Check if variable exists
|
||||
if not isinstance(variable, Variable):
|
||||
if not isinstance(variable, VariableBase):
|
||||
raise VariableNotFoundError(variable_selector=item.variable_selector)
|
||||
|
||||
# Check if operation is supported
|
||||
@ -179,26 +190,12 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||
# remove the duplicated items first.
|
||||
updated_variable_selectors = list(set(map(tuple, updated_variable_selectors)))
|
||||
|
||||
conv_var_updater = self._conv_var_updater_factory()
|
||||
# Update variables
|
||||
for selector in updated_variable_selectors:
|
||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
if not isinstance(variable, Variable):
|
||||
if not isinstance(variable, VariableBase):
|
||||
raise VariableNotFoundError(variable_selector=selector)
|
||||
process_data[variable.name] = variable.value
|
||||
|
||||
if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
|
||||
if not conversation_id:
|
||||
if self.invoke_from != InvokeFrom.DEBUGGER:
|
||||
raise ConversationIDNotFoundError
|
||||
else:
|
||||
conversation_id = conversation_id.value
|
||||
conv_var_updater.update(
|
||||
conversation_id=cast(str, conversation_id),
|
||||
variable=variable,
|
||||
)
|
||||
conv_var_updater.flush()
|
||||
updated_variables = [
|
||||
common_helpers.variable_to_processed_data(selector, seg)
|
||||
for selector in updated_variable_selectors
|
||||
@ -216,7 +213,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||
def _handle_item(
|
||||
self,
|
||||
*,
|
||||
variable: Variable,
|
||||
variable: VariableBase,
|
||||
operation: Operation,
|
||||
value: Any,
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user