mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
feat: Persist Variables for Enhanced Debugging Workflow (#20699)
This pull request introduces a feature aimed at improving the debugging experience during workflow editing. With the addition of variable persistence, the system will automatically retain the output variables from previously executed nodes. These persisted variables can then be reused when debugging subsequent nodes, eliminating the need for repetitive manual input. By streamlining this aspect of the workflow, the feature minimizes user errors and significantly reduces debugging effort, offering a smoother and more efficient experience. Key highlights of this change: - Automatic persistence of output variables for executed nodes. - Reuse of persisted variables to simplify input steps for nodes requiring them (e.g., `code`, `template`, `variable_assigner`). - Enhanced debugging experience with reduced friction. Closes #19735.
This commit is contained in:
39
api/core/workflow/conversation_variable_updater.py
Normal file
39
api/core/workflow/conversation_variable_updater.py
Normal file
@ -0,0 +1,39 @@
|
||||
import abc
|
||||
from typing import Protocol
|
||||
|
||||
from core.variables import Variable
|
||||
|
||||
|
||||
class ConversationVariableUpdater(Protocol):
|
||||
"""
|
||||
ConversationVariableUpdater defines an abstraction for updating conversation variable values.
|
||||
|
||||
It is intended for use by `v1.VariableAssignerNode` and `v2.VariableAssignerNode` when updating
|
||||
conversation variables.
|
||||
|
||||
Implementations may choose to batch updates. If batching is used, the `flush` method
|
||||
should be implemented to persist buffered changes, and `update`
|
||||
should handle buffering accordingly.
|
||||
|
||||
Note: Since implementations may buffer updates, instances of ConversationVariableUpdater
|
||||
are not thread-safe. Each VariableAssignerNode should create its own instance during execution.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, conversation_id: str, variable: "Variable") -> None:
|
||||
"""
|
||||
Updates the value of the specified conversation variable in the underlying storage.
|
||||
|
||||
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
|
||||
:param variable: The `Variable` instance containing the updated value.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def flush(self):
|
||||
"""
|
||||
Flushes all pending updates to the underlying storage system.
|
||||
|
||||
If the implementation does not buffer updates, this method can be a no-op.
|
||||
"""
|
||||
pass
|
||||
@ -7,12 +7,12 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from core.file import File, FileAttribute, file_manager
|
||||
from core.variables import Segment, SegmentGroup, Variable
|
||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||
from core.variables.segments import FileSegment, NoneSegment
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from factories import variable_factory
|
||||
|
||||
from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from ..enums import SystemVariableKey
|
||||
|
||||
VariableValue = Union[str, int, float, dict, list, File]
|
||||
|
||||
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
|
||||
@ -30,9 +30,11 @@ class VariablePool(BaseModel):
|
||||
# TODO: This user inputs is not used for pool.
|
||||
user_inputs: Mapping[str, Any] = Field(
|
||||
description="User inputs",
|
||||
default_factory=dict,
|
||||
)
|
||||
system_variables: Mapping[SystemVariableKey, Any] = Field(
|
||||
description="System variables",
|
||||
default_factory=dict,
|
||||
)
|
||||
environment_variables: Sequence[Variable] = Field(
|
||||
description="Environment variables.",
|
||||
@ -43,28 +45,7 @@ class VariablePool(BaseModel):
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
system_variables: Mapping[SystemVariableKey, Any] | None = None,
|
||||
user_inputs: Mapping[str, Any] | None = None,
|
||||
environment_variables: Sequence[Variable] | None = None,
|
||||
conversation_variables: Sequence[Variable] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
environment_variables = environment_variables or []
|
||||
conversation_variables = conversation_variables or []
|
||||
user_inputs = user_inputs or {}
|
||||
system_variables = system_variables or {}
|
||||
|
||||
super().__init__(
|
||||
system_variables=system_variables,
|
||||
user_inputs=user_inputs,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def model_post_init(self, context: Any, /) -> None:
|
||||
for key, value in self.system_variables.items():
|
||||
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
|
||||
# Add environment variables to the variable pool
|
||||
@ -91,12 +72,12 @@ class VariablePool(BaseModel):
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if len(selector) < 2:
|
||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||
raise ValueError("Invalid selector")
|
||||
|
||||
if isinstance(value, Variable):
|
||||
variable = value
|
||||
if isinstance(value, Segment):
|
||||
elif isinstance(value, Segment):
|
||||
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
|
||||
else:
|
||||
segment = variable_factory.build_segment(value)
|
||||
@ -118,7 +99,7 @@ class VariablePool(BaseModel):
|
||||
Raises:
|
||||
ValueError: If the selector is invalid.
|
||||
"""
|
||||
if len(selector) < 2:
|
||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||
return None
|
||||
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
|
||||
@ -66,6 +66,8 @@ class BaseNodeEvent(GraphEngineEvent):
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
# The version of the node, or "1" if not specified.
|
||||
node_version: str = "1"
|
||||
|
||||
|
||||
class NodeRunStartedEvent(BaseNodeEvent):
|
||||
|
||||
@ -314,6 +314,7 @@ class GraphEngine:
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
raise e
|
||||
|
||||
@ -627,6 +628,7 @@ class GraphEngine:
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
agent_strategy=agent_strategy,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
|
||||
max_retries = node_instance.node_data.retry_config.max_retries
|
||||
@ -677,6 +679,7 @@ class GraphEngine:
|
||||
error=run_result.error or "Unknown error",
|
||||
retry_index=retries,
|
||||
start_at=retry_start_at,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
time.sleep(retry_interval)
|
||||
break
|
||||
@ -712,6 +715,7 @@ class GraphEngine:
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
should_continue_retry = False
|
||||
else:
|
||||
@ -726,6 +730,7 @@ class GraphEngine:
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
should_continue_retry = False
|
||||
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
@ -786,6 +791,7 @@ class GraphEngine:
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
should_continue_retry = False
|
||||
|
||||
@ -803,6 +809,7 @@ class GraphEngine:
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
elif isinstance(event, RunRetrieverResourceEvent):
|
||||
yield NodeRunRetrieverResourceEvent(
|
||||
@ -817,6 +824,7 @@ class GraphEngine:
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
except GenerateTaskStoppedError:
|
||||
# trigger node run failed event
|
||||
@ -833,6 +841,7 @@ class GraphEngine:
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
|
||||
@ -18,7 +18,11 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
|
||||
class AnswerNode(BaseNode[AnswerNodeData]):
|
||||
_node_data_cls = AnswerNodeData
|
||||
_node_type: NodeType = NodeType.ANSWER
|
||||
_node_type = NodeType.ANSWER
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
@ -45,7 +49,10 @@ class AnswerNode(BaseNode[AnswerNodeData]):
|
||||
part = cast(TextGenerateRouteChunk, part)
|
||||
answer += part.text
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer, "files": files})
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"answer": answer, "files": ArrayFileSegment(value=files)},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
|
||||
@ -109,6 +109,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
from_variable_selector=[answer_node_id, "answer"],
|
||||
node_version=event.node_version,
|
||||
)
|
||||
else:
|
||||
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
|
||||
@ -134,6 +135,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
route_node_state=event.route_node_state,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
node_version=event.node_version,
|
||||
)
|
||||
|
||||
self.route_position[answer_node_id] += 1
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
@ -23,7 +23,7 @@ GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
|
||||
|
||||
class BaseNode(Generic[GenericNodeData]):
|
||||
_node_data_cls: type[GenericNodeData]
|
||||
_node_type: NodeType
|
||||
_node_type: ClassVar[NodeType]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -90,8 +90,38 @@ class BaseNode(Generic[GenericNodeData]):
|
||||
graph_config: Mapping[str, Any],
|
||||
config: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
"""Extracts references variable selectors from node configuration.
|
||||
|
||||
The `config` parameter represents the configuration for a specific node type and corresponds
|
||||
to the `data` field in the node definition object.
|
||||
|
||||
The returned mapping has the following structure:
|
||||
|
||||
{'1747829548239.#1747829667553.result#': ['1747829667553', 'result']}
|
||||
|
||||
For loop and iteration nodes, the mapping may look like this:
|
||||
|
||||
{
|
||||
"1748332301644.input_selector": ["1748332363630", "result"],
|
||||
"1748332325079.1748332325079.#sys.workflow_id#": ["sys", "workflow_id"],
|
||||
}
|
||||
|
||||
where `1748332301644` is the ID of the loop / iteration node,
|
||||
and `1748332325079` is the ID of the node inside the loop or iteration node.
|
||||
|
||||
Here, the key consists of two parts: the current node ID (provided as the `node_id`
|
||||
parameter to `_extract_variable_selector_to_variable_mapping`) and the variable selector,
|
||||
enclosed in `#` symbols. These two parts are separated by a dot (`.`).
|
||||
|
||||
The value is a list of string representing the variable selector, where the first element is the node ID
|
||||
of the referenced variable, and the second element is the variable name within that node.
|
||||
|
||||
The meaning of the above response is:
|
||||
|
||||
The node with ID `1747829548239` references the variable `result` from the node with
|
||||
ID `1747829667553`. For example, if `1747829548239` is a LLM node, its prompt may contain a
|
||||
reference to the `result` output variable of node `1747829667553`.
|
||||
|
||||
:param graph_config: graph config
|
||||
:param config: node config
|
||||
:return:
|
||||
@ -101,9 +131,10 @@ class BaseNode(Generic[GenericNodeData]):
|
||||
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
|
||||
|
||||
node_data = cls._node_data_cls(**config.get("data", {}))
|
||||
return cls._extract_variable_selector_to_variable_mapping(
|
||||
data = cls._extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
|
||||
)
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
@ -139,6 +170,16 @@ class BaseNode(Generic[GenericNodeData]):
|
||||
"""
|
||||
return self._node_type
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def version(cls) -> str:
|
||||
"""`node_version` returns the version of current node type."""
|
||||
# NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`.
|
||||
#
|
||||
# If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING`
|
||||
# in `api/core/workflow/nodes/__init__.py`.
|
||||
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
|
||||
|
||||
@property
|
||||
def should_continue_on_error(self) -> bool:
|
||||
"""judge if should continue on error
|
||||
|
||||
@ -40,6 +40,10 @@ class CodeNode(BaseNode[CodeNodeData]):
|
||||
|
||||
return code_provider.get_default_config()
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
# Get code language
|
||||
code_language = self.node_data.code_language
|
||||
@ -126,6 +130,9 @@ class CodeNode(BaseNode[CodeNodeData]):
|
||||
prefix: str = "",
|
||||
depth: int = 1,
|
||||
):
|
||||
# TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes.
|
||||
# Note that `_transform_result` may produce lists containing `None` values,
|
||||
# which don't conform to the type requirements of `Array*Segment` classes.
|
||||
if depth > dify_config.CODE_MAX_DEPTH:
|
||||
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
|
||||
|
||||
|
||||
@ -24,7 +24,7 @@ from configs import dify_config
|
||||
from core.file import File, FileTransferMethod, file_manager
|
||||
from core.helper import ssrf_proxy
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.variables.segments import FileSegment
|
||||
from core.variables.segments import ArrayStringSegment, FileSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
@ -45,6 +45,10 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
|
||||
_node_data_cls = DocumentExtractorNodeData
|
||||
_node_type = NodeType.DOCUMENT_EXTRACTOR
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self):
|
||||
variable_selector = self.node_data.variable_selector
|
||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector)
|
||||
@ -67,7 +71,7 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs={"text": extracted_text_list},
|
||||
outputs={"text": ArrayStringSegment(value=extracted_text_list)},
|
||||
)
|
||||
elif isinstance(value, File):
|
||||
extracted_text = _extract_text_from_file(value)
|
||||
|
||||
@ -9,6 +9,10 @@ class EndNode(BaseNode[EndNodeData]):
|
||||
_node_data_cls = EndNodeData
|
||||
_node_type = NodeType.END
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run node
|
||||
|
||||
@ -139,6 +139,7 @@ class EndStreamProcessor(StreamProcessor):
|
||||
route_node_state=event.route_node_state,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
node_version=event.node_version,
|
||||
)
|
||||
|
||||
self.route_position[end_node_id] += 1
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import Any, Optional
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
@ -60,6 +61,10 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
process_data = {}
|
||||
try:
|
||||
@ -92,7 +97,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
"status_code": response.status_code,
|
||||
"body": response.text if not files else "",
|
||||
"body": response.text if not files.value else "",
|
||||
"headers": response.headers,
|
||||
"files": files,
|
||||
},
|
||||
@ -166,7 +171,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
|
||||
return mapping
|
||||
|
||||
def extract_files(self, url: str, response: Response) -> list[File]:
|
||||
def extract_files(self, url: str, response: Response) -> ArrayFileSegment:
|
||||
"""
|
||||
Extract files from response by checking both Content-Type header and URL
|
||||
"""
|
||||
@ -178,7 +183,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
content_disposition_type = None
|
||||
|
||||
if not is_file:
|
||||
return files
|
||||
return ArrayFileSegment(value=[])
|
||||
|
||||
if parsed_content_disposition:
|
||||
content_disposition_filename = parsed_content_disposition.get_filename()
|
||||
@ -211,4 +216,4 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
)
|
||||
files.append(file)
|
||||
|
||||
return files
|
||||
return ArrayFileSegment(value=files)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from typing import Literal
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
@ -16,6 +17,10 @@ class IfElseNode(BaseNode[IfElseNodeData]):
|
||||
_node_data_cls = IfElseNodeData
|
||||
_node_type = NodeType.IF_ELSE
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run node
|
||||
@ -87,6 +92,22 @@ class IfElseNode(BaseNode[IfElseNodeData]):
|
||||
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: IfElseNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
var_mapping: dict[str, list[str]] = {}
|
||||
for case in node_data.cases or []:
|
||||
for condition in case.conditions:
|
||||
key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector))
|
||||
var_mapping[key] = condition.variable_selector
|
||||
|
||||
return var_mapping
|
||||
|
||||
|
||||
@deprecated("This function is deprecated. You should use the new cases structure.")
|
||||
def _should_not_use_old_function(
|
||||
|
||||
@ -11,6 +11,7 @@ from flask import Flask, current_app
|
||||
|
||||
from configs import dify_config
|
||||
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.workflow.entities.node_entities import (
|
||||
NodeRunResult,
|
||||
)
|
||||
@ -37,6 +38,7 @@ from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||
from factories.variable_factory import build_segment
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
from .exc import (
|
||||
@ -72,6 +74,10 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
"""
|
||||
Run the node.
|
||||
@ -85,10 +91,17 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
|
||||
|
||||
if isinstance(variable, NoneVariable) or len(variable.value) == 0:
|
||||
# Try our best to preserve the type informat.
|
||||
if isinstance(variable, ArraySegment):
|
||||
output = variable.model_copy(update={"value": []})
|
||||
else:
|
||||
output = ArrayAnySegment(value=[])
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"output": []},
|
||||
# TODO(QuantumGhost): is it possible to compute the type of `output`
|
||||
# from graph definition?
|
||||
outputs={"output": output},
|
||||
)
|
||||
)
|
||||
return
|
||||
@ -231,6 +244,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
# Flatten the list of lists
|
||||
if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs):
|
||||
outputs = [item for sublist in outputs for item in sublist]
|
||||
output_segment = build_segment(outputs)
|
||||
|
||||
yield IterationRunSucceededEvent(
|
||||
iteration_id=self.id,
|
||||
@ -247,7 +261,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"output": outputs},
|
||||
outputs={"output": output_segment},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
|
||||
@ -13,6 +13,10 @@ class IterationStartNode(BaseNode[IterationStartNodeData]):
|
||||
_node_data_cls = IterationStartNodeData
|
||||
_node_type = NodeType.ITERATION_START
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the node.
|
||||
|
||||
@ -24,6 +24,7 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.variables import StringSegment
|
||||
from core.variables.segments import ArrayObjectSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
@ -115,9 +116,12 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
# retrieve knowledge
|
||||
try:
|
||||
results = self._fetch_dataset_retriever(node_data=node_data, query=query)
|
||||
outputs = {"result": results}
|
||||
outputs = {"result": ArrayObjectSegment(value=results)}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
process_data=None,
|
||||
outputs=outputs, # type: ignore
|
||||
)
|
||||
|
||||
except KnowledgeRetrievalNodeError as e:
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Any, Literal, Union
|
||||
|
||||
from core.file import File
|
||||
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
@ -16,6 +17,10 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
|
||||
_node_data_cls = ListOperatorNodeData
|
||||
_node_type = NodeType.LIST_OPERATOR
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self):
|
||||
inputs: dict[str, list] = {}
|
||||
process_data: dict[str, list] = {}
|
||||
@ -30,7 +35,11 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
|
||||
if not variable.value:
|
||||
inputs = {"variable": []}
|
||||
process_data = {"variable": []}
|
||||
outputs = {"result": [], "first_record": None, "last_record": None}
|
||||
if isinstance(variable, ArraySegment):
|
||||
result = variable.model_copy(update={"value": []})
|
||||
else:
|
||||
result = ArrayAnySegment(value=[])
|
||||
outputs = {"result": result, "first_record": None, "last_record": None}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
@ -71,7 +80,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
|
||||
variable = self._apply_slice(variable)
|
||||
|
||||
outputs = {
|
||||
"result": variable.value,
|
||||
"result": variable,
|
||||
"first_record": variable.value[0] if variable.value else None,
|
||||
"last_record": variable.value[-1] if variable.value else None,
|
||||
}
|
||||
|
||||
@ -119,9 +119,6 @@ class FileSaverImpl(LLMFileSaver):
|
||||
size=len(data),
|
||||
related_id=tool_file.id,
|
||||
url=url,
|
||||
# TODO(QuantumGhost): how should I set the following key?
|
||||
# What's the difference between `remote_url` and `url`?
|
||||
# What's the purpose of `storage_key` and `dify_model_identity`?
|
||||
storage_key=tool_file.file_key,
|
||||
)
|
||||
|
||||
|
||||
@ -138,6 +138,10 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
def process_structured_output(text: str) -> Optional[dict[str, Any]]:
|
||||
"""Process structured output if enabled"""
|
||||
@ -255,7 +259,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
if structured_output:
|
||||
outputs["structured_output"] = structured_output
|
||||
if self._file_outputs is not None:
|
||||
outputs["files"] = self._file_outputs
|
||||
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
|
||||
@ -13,6 +13,10 @@ class LoopEndNode(BaseNode[LoopEndNodeData]):
|
||||
_node_data_cls = LoopEndNodeData
|
||||
_node_type = NodeType.LOOP_END
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the node.
|
||||
|
||||
@ -54,6 +54,10 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
_node_data_cls = LoopNodeData
|
||||
_node_type = NodeType.LOOP
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
"""Run the node."""
|
||||
# Get inputs
|
||||
@ -482,6 +486,13 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
|
||||
variable_mapping.update(sub_node_variable_mapping)
|
||||
|
||||
for loop_variable in node_data.loop_variables or []:
|
||||
if loop_variable.value_type == "variable":
|
||||
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
|
||||
# add loop variable to variable mapping
|
||||
selector = loop_variable.value
|
||||
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
|
||||
|
||||
# remove variable out from loop
|
||||
variable_mapping = {
|
||||
key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids
|
||||
|
||||
@ -13,6 +13,10 @@ class LoopStartNode(BaseNode[LoopStartNodeData]):
|
||||
_node_data_cls = LoopStartNodeData
|
||||
_node_type = NodeType.LOOP_START
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the node.
|
||||
|
||||
@ -25,6 +25,11 @@ from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as Var
|
||||
|
||||
LATEST_VERSION = "latest"
|
||||
|
||||
# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode.
|
||||
# Specifically, if you have introduced new node types, you should add them here.
|
||||
#
|
||||
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
|
||||
# hook. Try to avoid duplication of node information.
|
||||
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
|
||||
NodeType.START: {
|
||||
LATEST_VERSION: StartNode,
|
||||
|
||||
@ -7,6 +7,10 @@ from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.llm import ModelConfig, VisionConfig
|
||||
|
||||
|
||||
class _ParameterConfigError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ParameterConfig(BaseModel):
|
||||
"""
|
||||
Parameter Config.
|
||||
@ -27,6 +31,19 @@ class ParameterConfig(BaseModel):
|
||||
raise ValueError("Invalid parameter name, __reason and __is_success are reserved")
|
||||
return str(value)
|
||||
|
||||
def is_array_type(self) -> bool:
|
||||
return self.type in ("array[string]", "array[number]", "array[object]")
|
||||
|
||||
def element_type(self) -> Literal["string", "number", "object"]:
|
||||
if self.type == "array[number]":
|
||||
return "number"
|
||||
elif self.type == "array[string]":
|
||||
return "string"
|
||||
elif self.type == "array[object]":
|
||||
return "object"
|
||||
else:
|
||||
raise _ParameterConfigError(f"{self.type} is not array type.")
|
||||
|
||||
|
||||
class ParameterExtractorNodeData(BaseNodeData):
|
||||
"""
|
||||
|
||||
@ -25,6 +25,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
@ -32,6 +33,7 @@ from core.workflow.nodes.base.node import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.llm import ModelConfig, llm_utils
|
||||
from core.workflow.utils import variable_template_parser
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
|
||||
from .entities import ParameterExtractorNodeData
|
||||
from .exc import (
|
||||
@ -109,6 +111,10 @@ class ParameterExtractorNode(BaseNode):
|
||||
}
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self):
|
||||
"""
|
||||
Run the node.
|
||||
@ -584,28 +590,30 @@ class ParameterExtractorNode(BaseNode):
|
||||
elif parameter.type in {"string", "select"}:
|
||||
if isinstance(result[parameter.name], str):
|
||||
transformed_result[parameter.name] = result[parameter.name]
|
||||
elif parameter.type.startswith("array"):
|
||||
elif parameter.is_array_type():
|
||||
if isinstance(result[parameter.name], list):
|
||||
nested_type = parameter.type[6:-1]
|
||||
transformed_result[parameter.name] = []
|
||||
nested_type = parameter.element_type()
|
||||
assert nested_type is not None
|
||||
segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[])
|
||||
transformed_result[parameter.name] = segment_value
|
||||
for item in result[parameter.name]:
|
||||
if nested_type == "number":
|
||||
if isinstance(item, int | float):
|
||||
transformed_result[parameter.name].append(item)
|
||||
segment_value.value.append(item)
|
||||
elif isinstance(item, str):
|
||||
try:
|
||||
if "." in item:
|
||||
transformed_result[parameter.name].append(float(item))
|
||||
segment_value.value.append(float(item))
|
||||
else:
|
||||
transformed_result[parameter.name].append(int(item))
|
||||
segment_value.value.append(int(item))
|
||||
except ValueError:
|
||||
pass
|
||||
elif nested_type == "string":
|
||||
if isinstance(item, str):
|
||||
transformed_result[parameter.name].append(item)
|
||||
segment_value.value.append(item)
|
||||
elif nested_type == "object":
|
||||
if isinstance(item, dict):
|
||||
transformed_result[parameter.name].append(item)
|
||||
segment_value.value.append(item)
|
||||
|
||||
if parameter.name not in transformed_result:
|
||||
if parameter.type == "number":
|
||||
@ -615,7 +623,9 @@ class ParameterExtractorNode(BaseNode):
|
||||
elif parameter.type in {"string", "select"}:
|
||||
transformed_result[parameter.name] = ""
|
||||
elif parameter.type.startswith("array"):
|
||||
transformed_result[parameter.name] = []
|
||||
transformed_result[parameter.name] = build_segment_with_type(
|
||||
segment_type=SegmentType(parameter.type), value=[]
|
||||
)
|
||||
|
||||
return transformed_result
|
||||
|
||||
|
||||
@ -10,6 +10,10 @@ class StartNode(BaseNode[StartNodeData]):
|
||||
_node_data_cls = StartNodeData
|
||||
_node_type = NodeType.START
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
|
||||
system_inputs = self.graph_runtime_state.variable_pool.system_variables
|
||||
@ -18,5 +22,6 @@ class StartNode(BaseNode[StartNodeData]):
|
||||
# Set system variables as node outputs.
|
||||
for var in system_inputs:
|
||||
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
|
||||
outputs = dict(node_inputs)
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs)
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs)
|
||||
|
||||
@ -28,6 +28,10 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
|
||||
"config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
# Get variables
|
||||
variables = {}
|
||||
|
||||
@ -12,7 +12,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
@ -44,6 +44,10 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
_node_data_cls = ToolNodeData
|
||||
_node_type = NodeType.TOOL
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""
|
||||
Run the tool node
|
||||
@ -300,6 +304,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
assert isinstance(message.meta, File)
|
||||
files.append(message.meta["file"])
|
||||
elif message.type == ToolInvokeMessage.MessageType.LOG:
|
||||
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
|
||||
@ -363,7 +368,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"text": text, "files": files, "json": json, **variables},
|
||||
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json, **variables},
|
||||
metadata={
|
||||
**agent_execution_metadata,
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
from core.variables.segments import Segment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
@ -9,16 +12,20 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
|
||||
_node_data_cls = VariableAssignerNodeData
|
||||
_node_type = NodeType.VARIABLE_AGGREGATOR
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
# Get variables
|
||||
outputs = {}
|
||||
outputs: dict[str, Segment | Mapping[str, Segment]] = {}
|
||||
inputs = {}
|
||||
|
||||
if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
|
||||
for selector in self.node_data.variables:
|
||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
if variable is not None:
|
||||
outputs = {"output": variable.to_object()}
|
||||
outputs = {"output": variable}
|
||||
|
||||
inputs = {".".join(selector[1:]): variable.to_object()}
|
||||
break
|
||||
@ -28,7 +35,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
|
||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
|
||||
if variable is not None:
|
||||
outputs[group.group_name] = {"output": variable.to_object()}
|
||||
outputs[group.group_name] = {"output": variable}
|
||||
inputs[".".join(selector[1:])] = variable.to_object()
|
||||
break
|
||||
|
||||
|
||||
@ -1,19 +1,55 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from collections.abc import Mapping, MutableMapping, Sequence
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from core.variables import Variable
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from extensions.ext_database import db
|
||||
from models import ConversationVariable
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.variables import Segment
|
||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||
from core.variables.types import SegmentType
|
||||
|
||||
# Use double underscore (`__`) prefix for internal variables
|
||||
# to minimize risk of collision with user-defined variable names.
|
||||
_UPDATED_VARIABLES_KEY = "__updated_variables"
|
||||
|
||||
|
||||
def update_conversation_variable(conversation_id: str, variable: Variable):
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||
class UpdatedVariable(BaseModel):
|
||||
name: str
|
||||
selector: Sequence[str]
|
||||
value_type: SegmentType
|
||||
new_value: Any
|
||||
|
||||
|
||||
_T = TypeVar("_T", bound=MutableMapping[str, Any])
|
||||
|
||||
|
||||
def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable:
|
||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||
raise Exception("selector too short")
|
||||
node_id, var_name = selector[:2]
|
||||
return UpdatedVariable(
|
||||
name=var_name,
|
||||
selector=list(selector[:2]),
|
||||
value_type=seg.value_type,
|
||||
new_value=seg.value,
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
row = session.scalar(stmt)
|
||||
if not row:
|
||||
raise VariableOperatorNodeError("conversation variable not found in the database")
|
||||
row.data = variable.model_dump_json()
|
||||
session.commit()
|
||||
|
||||
|
||||
def set_updated_variables(m: _T, updates: Sequence[UpdatedVariable]) -> _T:
|
||||
m[_UPDATED_VARIABLES_KEY] = updates
|
||||
return m
|
||||
|
||||
|
||||
def get_updated_variables(m: Mapping[str, Any]) -> Sequence[UpdatedVariable] | None:
|
||||
updated_values = m.get(_UPDATED_VARIABLES_KEY, None)
|
||||
if updated_values is None:
|
||||
return None
|
||||
result = []
|
||||
for items in updated_values:
|
||||
if isinstance(items, UpdatedVariable):
|
||||
result.append(items)
|
||||
elif isinstance(items, dict):
|
||||
items = UpdatedVariable.model_validate(items)
|
||||
result.append(items)
|
||||
else:
|
||||
raise TypeError(f"Invalid updated variable: {items}, type={type(items)}")
|
||||
return result
|
||||
|
||||
38
api/core/workflow/nodes/variable_assigner/common/impl.py
Normal file
38
api/core/workflow/nodes/variable_assigner/common/impl.py
Normal file
@ -0,0 +1,38 @@
|
||||
from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.variables.variables import Variable
|
||||
from models.engine import db
|
||||
from models.workflow import ConversationVariable
|
||||
|
||||
from .exc import VariableOperatorNodeError
|
||||
|
||||
|
||||
class ConversationVariableUpdaterImpl:
|
||||
_engine: Engine | None
|
||||
|
||||
def __init__(self, engine: Engine | None = None) -> None:
|
||||
self._engine = engine
|
||||
|
||||
def _get_engine(self) -> Engine:
|
||||
if self._engine:
|
||||
return self._engine
|
||||
return db.engine
|
||||
|
||||
def update(self, conversation_id: str, variable: Variable):
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||
)
|
||||
with Session(self._get_engine()) as session:
|
||||
row = session.scalar(stmt)
|
||||
if not row:
|
||||
raise VariableOperatorNodeError("conversation variable not found in the database")
|
||||
row.data = variable.model_dump_json()
|
||||
session.commit()
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
|
||||
def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
|
||||
return ConversationVariableUpdaterImpl()
|
||||
@ -1,4 +1,9 @@
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeAlias
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
@ -7,16 +12,71 @@ from core.workflow.nodes.variable_assigner.common import helpers as common_helpe
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from factories import variable_factory
|
||||
|
||||
from ..common.impl import conversation_variable_updater_factory
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
|
||||
|
||||
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
|
||||
|
||||
|
||||
class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
||||
_node_data_cls = VariableAssignerData
|
||||
_node_type = NodeType.VARIABLE_ASSIGNER
|
||||
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph: "Graph",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
previous_node_id: Optional[str] = None,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
previous_node_id=previous_node_id,
|
||||
thread_pool_id=thread_pool_id,
|
||||
)
|
||||
self._conv_var_updater_factory = conv_var_updater_factory
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: VariableAssignerData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
mapping = {}
|
||||
assigned_variable_node_id = node_data.assigned_variable_selector[0]
|
||||
if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
selector_key = ".".join(node_data.assigned_variable_selector)
|
||||
key = f"{node_id}.#{selector_key}#"
|
||||
mapping[key] = node_data.assigned_variable_selector
|
||||
|
||||
selector_key = ".".join(node_data.input_variable_selector)
|
||||
key = f"{node_id}.#{selector_key}#"
|
||||
mapping[key] = node_data.input_variable_selector
|
||||
return mapping
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
assigned_variable_selector = self.node_data.assigned_variable_selector
|
||||
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
||||
original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector)
|
||||
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
|
||||
if not isinstance(original_variable, Variable):
|
||||
raise VariableOperatorNodeError("assigned variable not found")
|
||||
|
||||
@ -44,20 +104,28 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
||||
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
|
||||
|
||||
# Over write the variable.
|
||||
self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable)
|
||||
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
|
||||
|
||||
# TODO: Move database operation to the pipeline.
|
||||
# Update conversation variable.
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
|
||||
if not conversation_id:
|
||||
raise VariableOperatorNodeError("conversation_id not found")
|
||||
common_helpers.update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
|
||||
conv_var_updater = self._conv_var_updater_factory()
|
||||
conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable)
|
||||
conv_var_updater.flush()
|
||||
updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)]
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={
|
||||
"value": income_value.to_object(),
|
||||
},
|
||||
# NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`,
|
||||
# we still set `output_variables` as a list to ensure the schema of output is
|
||||
# compatible with `v2.VariableAssignerNode`.
|
||||
process_data=common_helpers.set_updated_variables({}, updated_variables),
|
||||
outputs={},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -12,6 +12,12 @@ class VariableOperationItem(BaseModel):
|
||||
variable_selector: Sequence[str]
|
||||
input_type: InputType
|
||||
operation: Operation
|
||||
# NOTE(QuantumGhost): The `value` field serves multiple purposes depending on context:
|
||||
#
|
||||
# 1. For CONSTANT input_type: Contains the literal value to be used in the operation.
|
||||
# 2. For VARIABLE input_type: Initially contains the selector of the source variable.
|
||||
# 3. During the variable updating procedure: The `value` field is reassigned to hold
|
||||
# the resolved actual value that will be applied to the target variable.
|
||||
value: Any | None = None
|
||||
|
||||
|
||||
|
||||
@ -29,3 +29,8 @@ class InvalidInputValueError(VariableOperatorNodeError):
|
||||
class ConversationIDNotFoundError(VariableOperatorNodeError):
|
||||
def __init__(self):
|
||||
super().__init__("conversation_id not found")
|
||||
|
||||
|
||||
class InvalidDataError(VariableOperatorNodeError):
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
|
||||
@ -1,34 +1,84 @@
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
from collections.abc import Callable, Mapping, MutableMapping, Sequence
|
||||
from typing import Any, TypeAlias, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
|
||||
|
||||
from . import helpers
|
||||
from .constants import EMPTY_VALUE_MAPPING
|
||||
from .entities import VariableAssignerNodeData
|
||||
from .entities import VariableAssignerNodeData, VariableOperationItem
|
||||
from .enums import InputType, Operation
|
||||
from .exc import (
|
||||
ConversationIDNotFoundError,
|
||||
InputTypeNotSupportedError,
|
||||
InvalidDataError,
|
||||
InvalidInputValueError,
|
||||
OperationNotSupportedError,
|
||||
VariableNotFoundError,
|
||||
)
|
||||
|
||||
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
|
||||
|
||||
|
||||
def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
|
||||
selector_node_id = item.variable_selector[0]
|
||||
if selector_node_id != CONVERSATION_VARIABLE_NODE_ID:
|
||||
return
|
||||
selector_str = ".".join(item.variable_selector)
|
||||
key = f"{node_id}.#{selector_str}#"
|
||||
mapping[key] = item.variable_selector
|
||||
|
||||
|
||||
def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
|
||||
# Keep this in sync with the logic in _run methods...
|
||||
if item.input_type != InputType.VARIABLE:
|
||||
return
|
||||
selector = item.value
|
||||
if not isinstance(selector, list):
|
||||
raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}")
|
||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||
raise InvalidDataError(f"selector too short, {node_id=}, {item=}")
|
||||
selector_str = ".".join(selector)
|
||||
key = f"{node_id}.#{selector_str}#"
|
||||
mapping[key] = selector
|
||||
|
||||
|
||||
class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
|
||||
_node_data_cls = VariableAssignerNodeData
|
||||
_node_type = NodeType.VARIABLE_ASSIGNER
|
||||
|
||||
def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
|
||||
return conversation_variable_updater_factory()
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "2"
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: VariableAssignerNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
var_mapping: dict[str, Sequence[str]] = {}
|
||||
for item in node_data.items:
|
||||
_target_mapping_from_item(var_mapping, node_id, item)
|
||||
_source_mapping_from_item(var_mapping, node_id, item)
|
||||
return var_mapping
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
inputs = self.node_data.model_dump()
|
||||
process_data: dict[str, Any] = {}
|
||||
@ -114,6 +164,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
|
||||
# remove the duplicated items first.
|
||||
updated_variable_selectors = list(set(map(tuple, updated_variable_selectors)))
|
||||
|
||||
conv_var_updater = self._conv_var_updater_factory()
|
||||
# Update variables
|
||||
for selector in updated_variable_selectors:
|
||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
@ -128,15 +179,23 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
|
||||
raise ConversationIDNotFoundError
|
||||
else:
|
||||
conversation_id = conversation_id.value
|
||||
common_helpers.update_conversation_variable(
|
||||
conv_var_updater.update(
|
||||
conversation_id=cast(str, conversation_id),
|
||||
variable=variable,
|
||||
)
|
||||
conv_var_updater.flush()
|
||||
updated_variables = [
|
||||
common_helpers.variable_to_processed_data(selector, seg)
|
||||
for selector in updated_variable_selectors
|
||||
if (seg := self.graph_runtime_state.variable_pool.get(selector)) is not None
|
||||
]
|
||||
|
||||
process_data = common_helpers.set_updated_variables(process_data, updated_variables)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs={},
|
||||
)
|
||||
|
||||
def _handle_item(
|
||||
|
||||
79
api/core/workflow/variable_loader.py
Normal file
79
api/core/workflow/variable_loader.py
Normal file
@ -0,0 +1,79 @@
|
||||
import abc
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.variables import Variable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
|
||||
class VariableLoader(Protocol):
|
||||
"""Interface for loading variables based on selectors.
|
||||
|
||||
A `VariableLoader` is responsible for retrieving additional variables required during the execution
|
||||
of a single node, which are not provided as user inputs.
|
||||
|
||||
NOTE(QuantumGhost): Typically, all variables loaded by a `VariableLoader` should belong to the same
|
||||
application and share the same `app_id`. However, this interface does not enforce that constraint,
|
||||
and the `app_id` parameter is intentionally omitted from `load_variables` to achieve separation of
|
||||
concern and allow for flexible implementations.
|
||||
|
||||
Implementations of `VariableLoader` should almost always have an `app_id` parameter in
|
||||
their constructor.
|
||||
|
||||
TODO(QuantumGhost): this is a temporally workaround. If we can move the creation of node instance into
|
||||
`WorkflowService.single_step_run`, we may get rid of this interface.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
|
||||
"""Load variables based on the provided selectors. If the selectors are empty,
|
||||
this method should return an empty list.
|
||||
|
||||
The order of the returned variables is not guaranteed. If the caller wants to ensure
|
||||
a specific order, they should sort the returned list themselves.
|
||||
|
||||
:param: selectors: a list of string list, each inner list should have at least two elements:
|
||||
- the first element is the node ID,
|
||||
- the second element is the variable name.
|
||||
:return: a list of Variable objects that match the provided selectors.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class _DummyVariableLoader(VariableLoader):
|
||||
"""A dummy implementation of VariableLoader that does not load any variables.
|
||||
Serves as a placeholder when no variable loading is needed.
|
||||
"""
|
||||
|
||||
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
|
||||
return []
|
||||
|
||||
|
||||
DUMMY_VARIABLE_LOADER = _DummyVariableLoader()
|
||||
|
||||
|
||||
def load_into_variable_pool(
|
||||
variable_loader: VariableLoader,
|
||||
variable_pool: VariablePool,
|
||||
variable_mapping: Mapping[str, Sequence[str]],
|
||||
user_inputs: Mapping[str, Any],
|
||||
):
|
||||
# Loading missing variable from draft var here, and set it into
|
||||
# variable_pool.
|
||||
variables_to_load: list[list[str]] = []
|
||||
for key, selector in variable_mapping.items():
|
||||
# NOTE(QuantumGhost): this logic needs to be in sync with
|
||||
# `WorkflowEntry.mapping_user_inputs_to_variable_pool`.
|
||||
node_variable_list = key.split(".")
|
||||
if len(node_variable_list) < 1:
|
||||
raise ValueError(f"Invalid variable key: {key}. It should have at least one element.")
|
||||
if key in user_inputs:
|
||||
continue
|
||||
node_variable_key = ".".join(node_variable_list[1:])
|
||||
if node_variable_key in user_inputs:
|
||||
continue
|
||||
if variable_pool.get(selector) is None:
|
||||
variables_to_load.append(list(selector))
|
||||
loaded = variable_loader.load_variables(variables_to_load)
|
||||
for var in loaded:
|
||||
variable_pool.add(var.selector, var)
|
||||
@ -92,7 +92,7 @@ class WorkflowCycleManager:
|
||||
) -> WorkflowExecution:
|
||||
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
|
||||
|
||||
outputs = WorkflowEntry.handle_special_values(outputs)
|
||||
# outputs = WorkflowEntry.handle_special_values(outputs)
|
||||
|
||||
workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED
|
||||
workflow_execution.outputs = outputs or {}
|
||||
@ -125,7 +125,7 @@ class WorkflowCycleManager:
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> WorkflowExecution:
|
||||
execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
|
||||
outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
|
||||
# outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
|
||||
|
||||
execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
|
||||
execution.outputs = outputs or {}
|
||||
@ -242,9 +242,9 @@ class WorkflowCycleManager:
|
||||
raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
|
||||
|
||||
# Process data
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
inputs = event.inputs
|
||||
process_data = event.process_data
|
||||
outputs = event.outputs
|
||||
|
||||
# Convert metadata keys to strings
|
||||
execution_metadata_dict = {}
|
||||
@ -289,7 +289,7 @@ class WorkflowCycleManager:
|
||||
# Process data
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
outputs = event.outputs
|
||||
|
||||
# Convert metadata keys to strings
|
||||
execution_metadata_dict = {}
|
||||
@ -326,7 +326,7 @@ class WorkflowCycleManager:
|
||||
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - created_at).total_seconds()
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
outputs = event.outputs
|
||||
|
||||
# Convert metadata keys to strings
|
||||
origin_metadata = {
|
||||
|
||||
@ -21,6 +21,7 @@ from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.event import NodeEvent
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
from factories import file_factory
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import (
|
||||
@ -119,7 +120,9 @@ class WorkflowEntry:
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_id: str,
|
||||
user_inputs: dict,
|
||||
user_inputs: Mapping[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
|
||||
"""
|
||||
Single step run workflow node
|
||||
@ -129,29 +132,14 @@ class WorkflowEntry:
|
||||
:param user_inputs: user inputs
|
||||
:return:
|
||||
"""
|
||||
# fetch node info from workflow graph
|
||||
workflow_graph = workflow.graph_dict
|
||||
if not workflow_graph:
|
||||
raise ValueError("workflow graph not found")
|
||||
|
||||
nodes = workflow_graph.get("nodes")
|
||||
if not nodes:
|
||||
raise ValueError("nodes not found in workflow graph")
|
||||
|
||||
# fetch node config from node id
|
||||
try:
|
||||
node_config = next(filter(lambda node: node["id"] == node_id, nodes))
|
||||
except StopIteration:
|
||||
raise ValueError("node id not found in workflow graph")
|
||||
node_config = workflow.get_node_config_by_id(node_id)
|
||||
node_config_data = node_config.get("data", {})
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(node_config.get("data", {}).get("type"))
|
||||
node_version = node_config.get("data", {}).get("version", "1")
|
||||
node_type = NodeType(node_config_data.get("type"))
|
||||
node_version = node_config_data.get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(environment_variables=workflow.environment_variables)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=workflow.graph_dict)
|
||||
|
||||
@ -182,16 +170,33 @@ class WorkflowEntry:
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
||||
# Loading missing variable from draft var here, and set it into
|
||||
# variable_pool.
|
||||
load_into_variable_pool(
|
||||
variable_loader=variable_loader,
|
||||
variable_pool=variable_pool,
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
)
|
||||
|
||||
cls.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# run node
|
||||
generator = node_instance.run()
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s",
|
||||
workflow.id,
|
||||
node_instance.id,
|
||||
node_instance.node_type,
|
||||
node_instance.version(),
|
||||
)
|
||||
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
|
||||
return node_instance, generator
|
||||
|
||||
@ -294,10 +299,20 @@ class WorkflowEntry:
|
||||
|
||||
return node_instance, generator
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s",
|
||||
node_instance.id,
|
||||
node_instance.node_type,
|
||||
node_instance.version(),
|
||||
)
|
||||
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
|
||||
|
||||
@staticmethod
|
||||
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
|
||||
# NOTE(QuantumGhost): Avoid using this function in new code.
|
||||
# Keep values structured as long as possible and only convert to dict
|
||||
# immediately before serialization (e.g., JSON serialization) to maintain
|
||||
# data integrity and type information.
|
||||
result = WorkflowEntry._handle_special_values(value)
|
||||
return result if isinstance(result, Mapping) or result is None else dict(result)
|
||||
|
||||
@ -324,10 +339,17 @@ class WorkflowEntry:
|
||||
cls,
|
||||
*,
|
||||
variable_mapping: Mapping[str, Sequence[str]],
|
||||
user_inputs: dict,
|
||||
user_inputs: Mapping[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
# NOTE(QuantumGhost): This logic should remain synchronized with
|
||||
# the implementation of `load_into_variable_pool`, specifically the logic about
|
||||
# variable existence checking.
|
||||
|
||||
# WARNING(QuantumGhost): The semantics of this method are not clearly defined,
|
||||
# and multiple parts of the codebase depend on its current behavior.
|
||||
# Modify with caution.
|
||||
for node_variable, variable_selector in variable_mapping.items():
|
||||
# fetch node id and variable key from node_variable
|
||||
node_variable_list = node_variable.split(".")
|
||||
|
||||
49
api/core/workflow/workflow_type_encoder.py
Normal file
49
api/core/workflow/workflow_type_encoder.py
Normal file
@ -0,0 +1,49 @@
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.file.models import File
|
||||
from core.variables import Segment
|
||||
|
||||
|
||||
class WorkflowRuntimeTypeEncoder(json.JSONEncoder):
|
||||
def default(self, o: Any):
|
||||
if isinstance(o, Segment):
|
||||
return o.value
|
||||
elif isinstance(o, File):
|
||||
return o.to_dict()
|
||||
elif isinstance(o, BaseModel):
|
||||
return o.model_dump(mode="json")
|
||||
else:
|
||||
return super().default(o)
|
||||
|
||||
|
||||
class WorkflowRuntimeTypeConverter:
|
||||
def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
|
||||
result = self._to_json_encodable_recursive(value)
|
||||
return result if isinstance(result, Mapping) or result is None else dict(result)
|
||||
|
||||
def _to_json_encodable_recursive(self, value: Any) -> Any:
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, (bool, int, str, float)):
|
||||
return value
|
||||
if isinstance(value, Segment):
|
||||
return self._to_json_encodable_recursive(value.value)
|
||||
if isinstance(value, File):
|
||||
return value.to_dict()
|
||||
if isinstance(value, BaseModel):
|
||||
return value.model_dump(mode="json")
|
||||
if isinstance(value, dict):
|
||||
res = {}
|
||||
for k, v in value.items():
|
||||
res[k] = self._to_json_encodable_recursive(v)
|
||||
return res
|
||||
if isinstance(value, list):
|
||||
res_list = []
|
||||
for item in value:
|
||||
res_list.append(self._to_json_encodable_recursive(item))
|
||||
return res_list
|
||||
return value
|
||||
Reference in New Issue
Block a user