mirror of
https://github.com/langgenius/dify.git
synced 2026-03-26 00:38:03 +08:00
Made-with: Cursor # Conflicts: # api/controllers/console/app/workflow_draft_variable.py # api/core/agent/cot_agent_runner.py # api/core/agent/cot_chat_agent_runner.py # api/core/agent/cot_completion_agent_runner.py # api/core/agent/fc_agent_runner.py # api/core/app/apps/advanced_chat/app_generator.py # api/core/app/apps/advanced_chat/app_runner.py # api/core/app/apps/agent_chat/app_runner.py # api/core/app/apps/workflow/app_generator.py # api/core/app/apps/workflow/app_runner.py # api/core/app/entities/app_invoke_entities.py # api/core/app/entities/queue_entities.py # api/core/llm_generator/output_parser/structured_output.py # api/core/workflow/workflow_entry.py # api/dify_graph/context/__init__.py # api/dify_graph/entities/tool_entities.py # api/dify_graph/file/file_manager.py # api/dify_graph/graph_engine/response_coordinator/coordinator.py # api/dify_graph/graph_events/node.py # api/dify_graph/node_events/node.py # api/dify_graph/nodes/agent/agent_node.py # api/dify_graph/nodes/llm/entities.py # api/dify_graph/nodes/llm/llm_utils.py # api/dify_graph/nodes/llm/node.py # api/dify_graph/nodes/question_classifier/question_classifier_node.py # api/dify_graph/runtime/graph_runtime_state.py # api/dify_graph/variables/segments.py # api/factories/variable_factory.py # api/services/variable_truncator.py # api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py # api/uv.lock # web/app/components/app-sidebar/app-info.tsx # web/app/components/app-sidebar/app-sidebar-dropdown.tsx # web/app/components/app/create-app-modal/index.spec.tsx # web/app/components/apps/__tests__/list.spec.tsx # web/app/components/apps/app-card.tsx # web/app/components/apps/list.tsx # web/app/components/header/account-dropdown/compliance.tsx # web/app/components/header/account-dropdown/index.tsx # web/app/components/header/account-dropdown/support.tsx # web/app/components/workflow-app/components/workflow-onboarding-modal/index.tsx # web/app/components/workflow/panel/debug-and-preview/hooks.ts # web/contract/console/apps.ts # web/contract/router.ts # web/eslint-suppressions.json # web/next.config.ts # web/pnpm-lock.yaml
313 lines
12 KiB
Python
313 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import re
|
|
from collections import defaultdict
|
|
from collections.abc import Mapping, Sequence
|
|
from copy import deepcopy
|
|
from typing import Annotated, Any, Union, cast
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from dify_graph.constants import (
|
|
CONVERSATION_VARIABLE_NODE_ID,
|
|
ENVIRONMENT_VARIABLE_NODE_ID,
|
|
RAG_PIPELINE_VARIABLE_NODE_ID,
|
|
SYSTEM_VARIABLE_NODE_ID,
|
|
)
|
|
from dify_graph.file import File, FileAttribute, file_manager
|
|
from dify_graph.system_variable import SystemVariable
|
|
from dify_graph.variables import Segment, SegmentGroup, VariableBase
|
|
from dify_graph.variables.consts import SELECTORS_LENGTH
|
|
from dify_graph.variables.segments import FileSegment, ObjectSegment
|
|
from dify_graph.variables.variables import RAGPipelineVariableInput, Variable
|
|
from factories import variable_factory
|
|
|
|
VariableValue = Union[str, int, float, dict[str, object], list[object], File]
|
|
|
|
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
|
|
|
|
|
|
class VariablePool(BaseModel):
|
|
# Variable dictionary is a dictionary for looking up variables by their selector.
|
|
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
|
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
|
# elements of the selector except the first one.
|
|
variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field(
|
|
description="Variables mapping",
|
|
default=defaultdict(dict),
|
|
)
|
|
|
|
# The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere.
|
|
user_inputs: Mapping[str, Any] = Field(
|
|
description="User inputs",
|
|
default_factory=dict,
|
|
)
|
|
system_variables: SystemVariable = Field(
|
|
description="System variables",
|
|
default_factory=SystemVariable.default,
|
|
)
|
|
environment_variables: Sequence[Variable] = Field(
|
|
description="Environment variables.",
|
|
default_factory=list[Variable],
|
|
)
|
|
conversation_variables: Sequence[Variable] = Field(
|
|
description="Conversation variables.",
|
|
default_factory=list[Variable],
|
|
)
|
|
rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
|
|
description="RAG pipeline variables.",
|
|
default_factory=list,
|
|
)
|
|
|
|
def model_post_init(self, context: Any, /):
|
|
# Create a mapping from field names to SystemVariableKey enum values
|
|
self._add_system_variables(self.system_variables)
|
|
# Add environment variables to the variable pool
|
|
for var in self.environment_variables:
|
|
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
|
|
# Add conversation variables to the variable pool
|
|
for var in self.conversation_variables:
|
|
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
|
|
# Add rag pipeline variables to the variable pool
|
|
if self.rag_pipeline_variables:
|
|
rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict)
|
|
for rag_var in self.rag_pipeline_variables:
|
|
node_id = rag_var.variable.belong_to_node_id
|
|
key = rag_var.variable.variable
|
|
value = rag_var.value
|
|
rag_pipeline_variables_map[node_id][key] = value
|
|
for key, value in rag_pipeline_variables_map.items():
|
|
self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value)
|
|
|
|
def add(self, selector: Sequence[str], value: Any, /):
|
|
"""
|
|
Add a variable to the variable pool.
|
|
|
|
This method accepts a selector path and a value, converting the value
|
|
to a Variable object if necessary before storing it in the pool.
|
|
|
|
Args:
|
|
selector: A two-element sequence containing [node_id, variable_name].
|
|
The selector must have exactly 2 elements to be valid.
|
|
value: The value to store. Can be a Variable, Segment, or any value
|
|
that can be converted to a Segment (str, int, float, dict, list, File).
|
|
|
|
Raises:
|
|
ValueError: If selector length is not exactly 2 elements.
|
|
|
|
Note:
|
|
While non-Segment values are currently accepted and automatically
|
|
converted, it's recommended to pass Segment or Variable objects directly.
|
|
"""
|
|
if len(selector) != SELECTORS_LENGTH:
|
|
raise ValueError(
|
|
f"Invalid selector: expected {SELECTORS_LENGTH} elements (node_id, variable_name), "
|
|
f"got {len(selector)} elements"
|
|
)
|
|
|
|
if isinstance(value, VariableBase):
|
|
variable = value
|
|
elif isinstance(value, Segment):
|
|
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
|
|
else:
|
|
segment = variable_factory.build_segment(value)
|
|
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
|
|
|
|
node_id, name = self._selector_to_keys(selector)
|
|
# Based on the definition of `Variable`,
|
|
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
|
|
self.variable_dictionary[node_id][name] = cast(Variable, variable)
|
|
|
|
@classmethod
|
|
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:
|
|
return selector[0], selector[1]
|
|
|
|
def _has(self, selector: Sequence[str]) -> bool:
|
|
node_id, name = self._selector_to_keys(selector)
|
|
if node_id not in self.variable_dictionary:
|
|
return False
|
|
if name not in self.variable_dictionary[node_id]:
|
|
return False
|
|
return True
|
|
|
|
def get(self, selector: Sequence[str], /) -> Segment | None:
|
|
"""
|
|
Retrieve a variable's value from the pool as a Segment.
|
|
|
|
This method supports both simple selectors [node_id, variable_name] and
|
|
extended selectors that include attribute access for FileSegment and
|
|
ObjectSegment types.
|
|
|
|
Args:
|
|
selector: A sequence with at least 2 elements:
|
|
- [node_id, variable_name]: Returns the full segment
|
|
- [node_id, variable_name, attr, ...]: Returns a nested value
|
|
from FileSegment (e.g., 'url', 'name') or ObjectSegment
|
|
|
|
Returns:
|
|
The Segment associated with the selector, or None if not found.
|
|
Returns None if selector has fewer than 2 elements.
|
|
|
|
Raises:
|
|
ValueError: If attempting to access an invalid FileAttribute.
|
|
"""
|
|
if len(selector) < SELECTORS_LENGTH:
|
|
return None
|
|
|
|
node_id, name = self._selector_to_keys(selector)
|
|
node_map = self.variable_dictionary.get(node_id)
|
|
if node_map is None:
|
|
return None
|
|
|
|
segment: Segment | None = node_map.get(name)
|
|
|
|
if segment is None:
|
|
return None
|
|
|
|
if len(selector) == 2:
|
|
return segment
|
|
|
|
if isinstance(segment, FileSegment):
|
|
attr = selector[2]
|
|
# Python support `attr in FileAttribute` after 3.12
|
|
if attr not in {item.value for item in FileAttribute}:
|
|
return None
|
|
attr = FileAttribute(attr)
|
|
attr_value = file_manager.get_attr(file=segment.value, attr=attr)
|
|
return variable_factory.build_segment(attr_value)
|
|
|
|
# Navigate through nested attributes
|
|
result: Any = segment
|
|
for attr in selector[2:]:
|
|
result = self._extract_value(result)
|
|
result = self._get_nested_attribute(result, attr)
|
|
if result is None:
|
|
return None
|
|
|
|
# Return result as Segment
|
|
return result if isinstance(result, Segment) else variable_factory.build_segment(result)
|
|
|
|
def _extract_value(self, obj: Any):
|
|
"""Extract the actual value from an ObjectSegment."""
|
|
return obj.value if isinstance(obj, ObjectSegment) else obj
|
|
|
|
def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Segment | None:
|
|
"""
|
|
Get a nested attribute from a dictionary-like object.
|
|
|
|
Args:
|
|
obj: The dictionary-like object to search.
|
|
attr: The key to look up.
|
|
|
|
Returns:
|
|
Segment | None:
|
|
The corresponding Segment built from the attribute value if the key exists,
|
|
otherwise None.
|
|
"""
|
|
if not isinstance(obj, dict) or attr not in obj:
|
|
return None
|
|
return variable_factory.build_segment(obj.get(attr))
|
|
|
|
def remove(self, selector: Sequence[str], /):
|
|
"""
|
|
Remove variables from the variable pool based on the given selector.
|
|
|
|
Args:
|
|
selector (Sequence[str]): A sequence of strings representing the selector.
|
|
|
|
Returns:
|
|
None
|
|
"""
|
|
if not selector:
|
|
return
|
|
if len(selector) == 1:
|
|
self.variable_dictionary[selector[0]] = {}
|
|
return
|
|
key, hash_key = self._selector_to_keys(selector)
|
|
self.variable_dictionary[key].pop(hash_key, None)
|
|
|
|
def convert_template(self, template: str, /):
|
|
parts = VARIABLE_PATTERN.split(template)
|
|
segments: list[Segment] = []
|
|
for part in filter(lambda x: x, parts):
|
|
if "." in part and (variable := self.get(part.split("."))):
|
|
segments.append(variable)
|
|
else:
|
|
segments.append(variable_factory.build_segment(part))
|
|
return SegmentGroup(value=segments)
|
|
|
|
def get_file(self, selector: Sequence[str], /) -> FileSegment | None:
|
|
segment = self.get(selector)
|
|
if isinstance(segment, FileSegment):
|
|
return segment
|
|
return None
|
|
|
|
def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]:
|
|
"""Return a copy of all variables stored under the given node prefix."""
|
|
|
|
nodes = self.variable_dictionary.get(prefix)
|
|
if not nodes:
|
|
return {}
|
|
|
|
result: dict[str, object] = {}
|
|
for key, variable in nodes.items():
|
|
value = variable.value
|
|
result[key] = deepcopy(value)
|
|
|
|
return result
|
|
|
|
def _add_system_variables(self, system_variable: SystemVariable):
|
|
sys_var_mapping = system_variable.to_dict()
|
|
for key, value in sys_var_mapping.items():
|
|
if value is None:
|
|
continue
|
|
selector = (SYSTEM_VARIABLE_NODE_ID, key)
|
|
# If the system variable already exists, do not add it again.
|
|
# This ensures that we can keep the id of the system variables intact.
|
|
if self._has(selector):
|
|
continue
|
|
self.add(selector, value)
|
|
|
|
def resolve_nested_node(
|
|
self,
|
|
nested_node_config: Mapping[str, Any],
|
|
/,
|
|
*,
|
|
use_extractor: bool = True,
|
|
parameter_name: str = "",
|
|
) -> tuple[Any, bool]:
|
|
"""
|
|
Resolve a nested_node parameter value.
|
|
|
|
Args:
|
|
nested_node_config: Config dict containing output_selector, null_strategy, default_value,
|
|
and optionally extractor_node_id
|
|
use_extractor: If True (mention format), build selector as extractor_node_id + output_selector.
|
|
If False (variable format), use output_selector directly.
|
|
parameter_name: Name of the parameter being resolved (for error messages)
|
|
|
|
Returns:
|
|
Tuple of (resolved_value, found)
|
|
"""
|
|
output_selector = list(nested_node_config.get("output_selector", []))
|
|
null_strategy = nested_node_config.get("null_strategy", "raise_error")
|
|
default_value = nested_node_config.get("default_value")
|
|
|
|
extractor_node_id = nested_node_config.get("extractor_node_id")
|
|
if not extractor_node_id:
|
|
raise ValueError(f"Missing extractor_node_id for parameter '{parameter_name}'")
|
|
full_selector = [extractor_node_id] + output_selector
|
|
|
|
variable = self.get(full_selector)
|
|
if variable is None:
|
|
if null_strategy == "use_default":
|
|
return default_value, False
|
|
raise ValueError(f"Variable {full_selector} not found for parameter '{parameter_name}'")
|
|
|
|
return variable.value, True
|
|
|
|
@classmethod
|
|
def empty(cls) -> VariablePool:
|
|
"""Create an empty variable pool."""
|
|
return cls(system_variables=SystemVariable.default())
|