mirror of
https://github.com/langgenius/dify.git
synced 2026-03-04 23:36:20 +08:00
275 lines
11 KiB
Python
275 lines
11 KiB
Python
from __future__ import annotations
|
|
|
|
import re
|
|
from collections import defaultdict
|
|
from collections.abc import Mapping, Sequence
|
|
from copy import deepcopy
|
|
from typing import Annotated, Any, Union, cast
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from dify_graph.constants import (
|
|
CONVERSATION_VARIABLE_NODE_ID,
|
|
ENVIRONMENT_VARIABLE_NODE_ID,
|
|
RAG_PIPELINE_VARIABLE_NODE_ID,
|
|
SYSTEM_VARIABLE_NODE_ID,
|
|
)
|
|
from dify_graph.file import File, FileAttribute, file_manager
|
|
from dify_graph.system_variable import SystemVariable
|
|
from dify_graph.variables import Segment, SegmentGroup, VariableBase
|
|
from dify_graph.variables.consts import SELECTORS_LENGTH
|
|
from dify_graph.variables.segments import FileSegment, ObjectSegment
|
|
from dify_graph.variables.variables import RAGPipelineVariableInput, Variable
|
|
from factories import variable_factory
|
|
|
|
# Plain (non-Segment) Python value kinds; the same kinds VariablePool.add lists
# as convertible to a Segment.
# NOTE(review): not referenced in the visible part of this file - presumably
# imported by other modules; confirm before removing.
VariableValue = Union[
    str,
    int,
    float,
    dict[str, object],
    list[object],
    File,
]
|
|
|
|
# Matches template placeholders of the form ``{{#node_id.name#}}`` (optionally
# with further dotted attributes). The single capture group is the dotted
# selector without the surrounding ``{{#`` / ``#}}`` markers.
VARIABLE_PATTERN = re.compile(
    r"\{\{#"  # opening marker "{{#"
    r"("  # start of the captured selector
    r"[a-zA-Z0-9_]{1,50}"  # first element: 1-50 word characters
    r"(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}"  # 1-10 ".identifier" parts, each at most 30 chars
    r")"  # end of the captured selector
    r"#\}\}"  # closing marker "#}}"
)
|
|
|
|
|
|
class VariablePool(BaseModel):
    """
    Two-level store of variables addressed by ``[node_id, variable_name]`` selectors.

    On construction (see ``model_post_init``) the pool registers the system,
    environment, conversation and RAG pipeline variables under their dedicated
    node ids, so they can be looked up like any other variable.
    """

    # Variable dictionary is a dictionary for looking up variables by their selector.
    # The first element of the selector is the node id; it is the first-level key.
    # The second element of the selector (the variable name) is the second-level key.
    variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field(
        description="Variables mapping",
        default=defaultdict(dict),
    )

    # The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere.
    user_inputs: Mapping[str, Any] = Field(
        description="User inputs",
        default_factory=dict,
    )
    system_variables: SystemVariable = Field(
        description="System variables",
        default_factory=SystemVariable.default,
    )
    environment_variables: Sequence[Variable] = Field(
        description="Environment variables.",
        default_factory=list[Variable],
    )
    conversation_variables: Sequence[Variable] = Field(
        description="Conversation variables.",
        default_factory=list[Variable],
    )
    rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
        description="RAG pipeline variables.",
        default_factory=list,
    )

    def model_post_init(self, context: Any, /) -> None:
        """
        Register the declared variable fields in ``variable_dictionary``.

        Args:
            context: Passed by pydantic; not used here.
        """
        # System variables already present in the pool are left untouched
        # (see `_add_system_variables`).
        self._add_system_variables(self.system_variables)
        # Add environment variables to the variable pool
        for var in self.environment_variables:
            self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
        # Add conversation variables to the variable pool
        for var in self.conversation_variables:
            self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
        # Add rag pipeline variables to the variable pool
        if self.rag_pipeline_variables:
            # Group the inputs by the node they belong to: {node_id: {variable: value}}.
            rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict)
            for rag_var in self.rag_pipeline_variables:
                node_id = rag_var.variable.belong_to_node_id
                key = rag_var.variable.variable
                rag_pipeline_variables_map[node_id][key] = rag_var.value
            # Each group is stored as ONE variable named after the owning node id,
            # whose value is the {variable: value} mapping.
            for node_id, node_values in rag_pipeline_variables_map.items():
                self.add((RAG_PIPELINE_VARIABLE_NODE_ID, node_id), node_values)

    def add(self, selector: Sequence[str], value: Any, /) -> None:
        """
        Add a variable to the variable pool.

        This method accepts a selector path and a value, converting the value
        to a Variable object if necessary before storing it in the pool.
        An existing variable stored under the same selector is overwritten.

        Args:
            selector: A two-element sequence containing [node_id, variable_name].
                The selector must have exactly 2 elements to be valid.
            value: The value to store. Can be a Variable, Segment, or any value
                that can be converted to a Segment (str, int, float, dict, list, File).

        Raises:
            ValueError: If selector length is not exactly 2 elements.

        Note:
            While non-Segment values are currently accepted and automatically
            converted, it's recommended to pass Segment or Variable objects directly.
        """
        if len(selector) != SELECTORS_LENGTH:
            raise ValueError(
                f"Invalid selector: expected {SELECTORS_LENGTH} elements (node_id, variable_name), "
                f"got {len(selector)} elements"
            )

        if isinstance(value, VariableBase):
            # Already a variable: stored as-is (its own selector is not rewritten).
            variable = value
        elif isinstance(value, Segment):
            variable = variable_factory.segment_to_variable(segment=value, selector=selector)
        else:
            segment = variable_factory.build_segment(value)
            variable = variable_factory.segment_to_variable(segment=segment, selector=selector)

        node_id, name = self._selector_to_keys(selector)
        # Based on the definition of `Variable`,
        # `VariableBase` instances can be safely used as `Variable` since they are compatible.
        self.variable_dictionary[node_id][name] = cast(Variable, variable)

    @classmethod
    def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:
        """Return the (node_id, variable_name) pair, i.e. the first two selector elements."""
        return selector[0], selector[1]

    def _has(self, selector: Sequence[str]) -> bool:
        """Return True if a variable is stored under the selector's first two elements."""
        node_id, name = self._selector_to_keys(selector)
        # Membership tests (rather than indexing) so the defaultdict does not
        # create an empty entry for an unknown node id.
        if node_id not in self.variable_dictionary:
            return False
        return name in self.variable_dictionary[node_id]

    def get(self, selector: Sequence[str], /) -> Segment | None:
        """
        Retrieve a variable's value from the pool as a Segment.

        This method supports both simple selectors [node_id, variable_name] and
        extended selectors that include attribute access for FileSegment and
        ObjectSegment types.

        Args:
            selector: A sequence with at least 2 elements:
                - [node_id, variable_name]: Returns the full segment
                - [node_id, variable_name, attr, ...]: Returns a nested value
                  from FileSegment (e.g., 'url', 'name') or ObjectSegment

        Returns:
            The Segment associated with the selector, or None if not found.
            Returns None if selector has fewer than 2 elements, if the requested
            file attribute is not a valid FileAttribute value, or if any step of
            a nested lookup is missing or is not a dict.

        Note:
            For a FileSegment only the first extra selector element is used;
            any further elements are ignored.
        """
        if len(selector) < SELECTORS_LENGTH:
            return None

        node_id, name = self._selector_to_keys(selector)
        # `.get` keeps the lookup read-only (no defaultdict auto-creation).
        node_map = self.variable_dictionary.get(node_id)
        if node_map is None:
            return None

        segment: Segment | None = node_map.get(name)
        if segment is None:
            return None

        if len(selector) == SELECTORS_LENGTH:
            return segment

        if isinstance(segment, FileSegment):
            attr = selector[SELECTORS_LENGTH]
            # Python support `attr in FileAttribute` after 3.12
            if attr not in {item.value for item in FileAttribute}:
                return None
            attr = FileAttribute(attr)
            attr_value = file_manager.get_attr(file=segment.value, attr=attr)
            return variable_factory.build_segment(attr_value)

        # Navigate through nested attributes
        result: Any = segment
        for attr in selector[SELECTORS_LENGTH:]:
            # Unwrap ObjectSegment -> dict, then descend one key.
            result = self._extract_value(result)
            result = self._get_nested_attribute(result, attr)
            if result is None:
                return None

        # Return result as Segment
        return result if isinstance(result, Segment) else variable_factory.build_segment(result)

    def _extract_value(self, obj: Any):
        """Extract the actual value from an ObjectSegment; return anything else unchanged."""
        return obj.value if isinstance(obj, ObjectSegment) else obj

    def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Segment | None:
        """
        Get a nested attribute from a dictionary-like object.

        Args:
            obj: The object to search. Only ``dict`` instances are searched;
                any other object yields None.
            attr: The key to look up.

        Returns:
            Segment | None:
                The corresponding Segment built from the attribute value if the key exists,
                otherwise None.
        """
        if not isinstance(obj, dict) or attr not in obj:
            return None
        return variable_factory.build_segment(obj[attr])

    def remove(self, selector: Sequence[str], /) -> None:
        """
        Remove variables from the variable pool based on the given selector.

        Args:
            selector (Sequence[str]): A sequence of strings representing the selector.
                - empty: nothing happens
                - [node_id]: every variable of that node is dropped (the node
                  is reset to an empty mapping)
                - [node_id, variable_name, ...]: that single variable is removed;
                  elements after the second are ignored

        Returns:
            None. Removing a variable that does not exist is a no-op.
        """
        if not selector:
            return
        if len(selector) == 1:
            self.variable_dictionary[selector[0]] = {}
            return
        node_id, name = self._selector_to_keys(selector)
        # Look the node up with `.get` instead of indexing: indexing the
        # defaultdict would create an empty entry for an unknown node id, so a
        # removal that finds nothing would still mutate the pool.
        node_map = self.variable_dictionary.get(node_id)
        if node_map is not None:
            node_map.pop(name, None)

    def convert_template(self, template: str, /):
        """
        Split a template into literal text and resolved variables.

        Placeholders matching ``VARIABLE_PATTERN`` are replaced by the segment
        stored in the pool for their dotted selector. A placeholder whose
        variable cannot be found is kept as its selector text (without the
        surrounding markers). Literal text is always kept as text.

        Args:
            template: Text that may contain variable placeholders.

        Returns:
            A SegmentGroup holding the segments in template order.
        """
        # `re.split` with exactly one capture group yields
        # [text, selector, text, selector, ..., text]: captured selectors sit at
        # the odd indices. Relying on the index (instead of `"." in part`)
        # prevents literal text that merely looks like "node.var" from being
        # replaced by a variable's value.
        parts = VARIABLE_PATTERN.split(template)
        segments: list[Segment] = []
        for index, part in enumerate(parts):
            if not part:
                # Empty literal chunks (e.g. between adjacent placeholders) are skipped.
                continue
            if index % 2 == 1:
                variable = self.get(part.split("."))
                if variable is not None:
                    segments.append(variable)
                    continue
            segments.append(variable_factory.build_segment(part))
        return SegmentGroup(value=segments)

    def get_file(self, selector: Sequence[str], /) -> FileSegment | None:
        """Return the segment for ``selector`` only if it is a FileSegment, else None."""
        segment = self.get(selector)
        if isinstance(segment, FileSegment):
            return segment
        return None

    def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]:
        """Return a copy of all variables stored under the given node prefix."""
        nodes = self.variable_dictionary.get(prefix)
        if not nodes:
            return {}

        # Deep-copy the values so callers cannot mutate what the pool holds.
        return {key: deepcopy(variable.value) for key, variable in nodes.items()}

    def _add_system_variables(self, system_variable: SystemVariable) -> None:
        """Register every non-None system variable under ``SYSTEM_VARIABLE_NODE_ID``."""
        sys_var_mapping = system_variable.to_dict()
        for key, value in sys_var_mapping.items():
            if value is None:
                continue
            selector = (SYSTEM_VARIABLE_NODE_ID, key)
            # If the system variable already exists, do not add it again.
            # This ensures that we can keep the id of the system variables intact.
            if self._has(selector):
                continue
            self.add(selector, value)

    @classmethod
    def empty(cls) -> VariablePool:
        """Create an empty variable pool."""
        return cls(system_variables=SystemVariable.default())