mirror of
https://github.com/langgenius/dify.git
synced 2026-04-30 07:28:05 +08:00
refactor(api): rename dify_graph to graphon (#34095)
This commit is contained in:
279
api/graphon/runtime/variable_pool.py
Normal file
279
api/graphon/runtime/variable_pool.py
Normal file
@ -0,0 +1,279 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from typing import Annotated, Any, Union, cast
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from graphon.file import File, FileAttribute, file_manager
|
||||
from graphon.variables import Segment, SegmentGroup, VariableBase, build_segment, segment_to_variable
|
||||
from graphon.variables.consts import SELECTORS_LENGTH
|
||||
from graphon.variables.segments import FileSegment, ObjectSegment
|
||||
from graphon.variables.variables import RAGPipelineVariableInput, Variable
|
||||
|
||||
# Alias for the raw (non-Segment) value types listed in VariablePool.add's
# docstring as convertible to a Segment.
VariableValue = Union[
    str,
    int,
    float,
    dict[str, object],
    list[object],
    File,
]
|
||||
|
||||
# Matches template references of the form ``{{#node_id.name#}}`` or
# ``{{#node_id.name.attr...#}}``. The single capture group holds the dotted
# selector path: a 1-50 character head followed by 1-10 dot-separated parts,
# each starting with a letter or underscore and at most 30 characters long.
VARIABLE_PATTERN = re.compile(
    r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
)
|
||||
|
||||
|
||||
def _default_variable_dictionary() -> defaultdict[str, dict[str, Variable]]:
    """Create the empty two-level mapping used as the pool's default storage.

    Returns:
        A ``defaultdict`` whose missing keys are filled with a fresh ``dict``.
    """
    empty_mapping: defaultdict[str, dict[str, Variable]] = defaultdict(dict)
    return empty_mapping
|
||||
|
||||
|
||||
class VariablePool(BaseModel):
    """Two-level store of variables addressed by selector.

    Variables are kept in ``variable_dictionary`` under
    ``[node_id][variable_name]``. The remaining fields are legacy constructor
    inputs that are folded into ``variable_dictionary`` after validation and
    then cleared.
    """

    _SYSTEM_VARIABLE_NODE_ID = "sys"
    _ENVIRONMENT_VARIABLE_NODE_ID = "env"
    _CONVERSATION_VARIABLE_NODE_ID = "conversation"
    _RAG_PIPELINE_VARIABLE_NODE_ID = "rag"

    # Variable dictionary is a dictionary for looking up variables by their selector.
    # The first element of the selector is the node id, it's the first-level key in the dictionary.
    # The second element of the selector is the variable name, it's the key in the second-level
    # dictionary (see `_selector_to_keys`). Any further selector elements are resolved at read
    # time by `get` and are never used as storage keys.
    variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field(
        description="Variables mapping",
        default_factory=_default_variable_dictionary,
    )
    system_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True)
    environment_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True)
    conversation_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True)
    rag_pipeline_variables: Sequence[RAGPipelineVariableInput] = Field(default_factory=tuple, exclude=True)
    user_inputs: Mapping[str, Any] = Field(default_factory=dict, exclude=True)

    @model_validator(mode="after")
    def _load_legacy_bootstrap_inputs(self) -> VariablePool:
        """
        Accept legacy constructor kwargs that still appear throughout the workflow
        layer while keeping serialized state focused on `variable_dictionary`.

        System, environment, conversation and RAG pipeline variables are copied
        into `variable_dictionary`; `user_inputs` is accepted but not ingested.
        """

        self._ingest_legacy_variables(self.system_variables, node_id=self._SYSTEM_VARIABLE_NODE_ID)
        self._ingest_legacy_variables(self.environment_variables, node_id=self._ENVIRONMENT_VARIABLE_NODE_ID)
        self._ingest_legacy_variables(self.conversation_variables, node_id=self._CONVERSATION_VARIABLE_NODE_ID)
        self._ingest_legacy_rag_variables(self.rag_pipeline_variables)

        # These kwargs are accepted for compatibility but should not affect the
        # stable serialized form or model equality.
        self.system_variables = ()
        self.environment_variables = ()
        self.conversation_variables = ()
        self.rag_pipeline_variables = ()
        self.user_inputs = {}
        return self

    def _ingest_legacy_variables(self, variables: Sequence[Variable], *, node_id: str) -> None:
        """Store each variable under ``[node_id, variable.name]``.

        A variable whose own selector differs from that path is copied with
        the selector rewritten before being stored.
        """
        for variable in variables:
            selector = [node_id, variable.name]
            normalized_variable = variable
            if list(variable.selector) != selector:
                normalized_variable = variable.model_copy(update={"selector": selector})
            self.add(normalized_variable.selector, normalized_variable)

    def _ingest_legacy_rag_variables(self, rag_pipeline_variables: Sequence[RAGPipelineVariableInput]) -> None:
        """Group RAG pipeline inputs by owning node and store one mapping per node.

        Each group is stored under ``["rag", belong_to_node_id]`` as a dict of
        ``{variable_name: value}``.
        """
        if not rag_pipeline_variables:
            return

        values_by_node_id: defaultdict[str, dict[str, Any]] = defaultdict(dict)
        for rag_variable_input in rag_pipeline_variables:
            values_by_node_id[rag_variable_input.variable.belong_to_node_id][rag_variable_input.variable.variable] = (
                rag_variable_input.value
            )

        for node_id, value in values_by_node_id.items():
            self.add((self._RAG_PIPELINE_VARIABLE_NODE_ID, node_id), value)

    def add(self, selector: Sequence[str], value: Any, /) -> None:
        """
        Add a variable to the variable pool.

        This method accepts a selector path and a value, converting the value
        to a Variable object if necessary before storing it in the pool.

        Args:
            selector: A two-element sequence containing [node_id, variable_name].
                The selector must have exactly 2 elements to be valid.
            value: The value to store. Can be a Variable, Segment, or any value
                that can be converted to a Segment (str, int, float, dict, list, File).

        Raises:
            ValueError: If selector length is not exactly 2 elements.

        Note:
            While non-Segment values are currently accepted and automatically
            converted, it's recommended to pass Segment or Variable objects directly.
            A value that is already a variable is stored as-is under the keys
            derived from `selector`; its own selector is not rewritten here.
        """
        if len(selector) != SELECTORS_LENGTH:
            raise ValueError(
                f"Invalid selector: expected {SELECTORS_LENGTH} elements (node_id, variable_name), "
                f"got {len(selector)} elements"
            )

        if isinstance(value, VariableBase):
            variable = value
        elif isinstance(value, Segment):
            variable = segment_to_variable(segment=value, selector=selector)
        else:
            segment = build_segment(value)
            variable = segment_to_variable(segment=segment, selector=selector)

        node_id, name = self._selector_to_keys(selector)
        # Based on the definition of `Variable`,
        # `VariableBase` instances can be safely used as `Variable` since they are compatible.
        self.variable_dictionary[node_id][name] = cast(Variable, variable)

    @classmethod
    def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:
        """Return the ``(node_id, variable_name)`` storage keys for a selector.

        Only the first two elements are used; callers must ensure the selector
        has at least two elements.
        """
        return selector[0], selector[1]

    def _has(self, selector: Sequence[str]) -> bool:
        """Return True if a variable is stored under the selector's first two elements."""
        node_id, name = self._selector_to_keys(selector)
        # `.get` avoids creating an empty entry in the defaultdict for unknown nodes.
        node_map = self.variable_dictionary.get(node_id)
        return node_map is not None and name in node_map

    def get(self, selector: Sequence[str], /) -> Segment | None:
        """
        Retrieve a variable's value from the pool as a Segment.

        This method supports both simple selectors [node_id, variable_name] and
        extended selectors that include attribute access for FileSegment and
        ObjectSegment types.

        Args:
            selector: A sequence with at least 2 elements:
                - [node_id, variable_name]: Returns the full segment
                - [node_id, variable_name, attr, ...]: Returns a nested value
                  from FileSegment (e.g., 'url', 'name') or ObjectSegment

        Returns:
            The Segment associated with the selector, or None if not found.
            Returns None if selector has fewer than 2 elements, if a nested key
            is missing, or if the third element is not a valid FileAttribute
            value when the stored variable is a FileSegment.

        Note:
            For a FileSegment only the third selector element is consulted;
            any further elements are ignored.
        """
        if len(selector) < SELECTORS_LENGTH:
            return None

        node_id, name = self._selector_to_keys(selector)
        node_map = self.variable_dictionary.get(node_id)
        if node_map is None:
            return None

        segment: Segment | None = node_map.get(name)

        if segment is None:
            return None

        if len(selector) == SELECTORS_LENGTH:
            return segment

        if isinstance(segment, FileSegment):
            attr = selector[2]
            # Python support `attr in FileAttribute` after 3.12
            if attr not in {item.value for item in FileAttribute}:
                return None
            attr = FileAttribute(attr)
            attr_value = file_manager.get_attr(file=segment.value, attr=attr)
            return build_segment(attr_value)

        # Navigate through nested attributes
        result: Any = segment
        for attr in selector[2:]:
            result = self._extract_value(result)
            result = self._get_nested_attribute(result, attr)
            if result is None:
                return None

        # Return result as Segment
        return result if isinstance(result, Segment) else build_segment(result)

    def _extract_value(self, obj: Any):
        """Extract the actual value from an ObjectSegment."""
        return obj.value if isinstance(obj, ObjectSegment) else obj

    def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Segment | None:
        """
        Get a nested attribute from a dictionary-like object.

        Args:
            obj: The dictionary-like object to search.
            attr: The key to look up.

        Returns:
            Segment | None:
                The corresponding Segment built from the attribute value if the key exists,
                otherwise None. None is also returned when `obj` is not a `dict`.
        """
        if not isinstance(obj, dict) or attr not in obj:
            return None
        return build_segment(obj.get(attr))

    def remove(self, selector: Sequence[str], /) -> None:
        """
        Remove variables from the variable pool based on the given selector.

        Args:
            selector (Sequence[str]): A sequence of strings representing the selector.
                An empty selector is a no-op. A one-element selector resets the
                node's mapping to an empty dict. Otherwise the variable named by
                the first two elements is removed if present.

        Returns:
            None
        """
        if not selector:
            return
        if len(selector) == 1:
            self.variable_dictionary[selector[0]] = {}
            return
        key, hash_key = self._selector_to_keys(selector)
        self.variable_dictionary[key].pop(hash_key, None)

    def convert_template(self, template: str, /) -> SegmentGroup:
        """Split a template into literal text and resolved variable segments.

        References matching `VARIABLE_PATTERN` are looked up with `get`; a
        reference that does not resolve is emitted as its dotted selector text.
        Literal text is always emitted as text, even when it happens to look
        like a selector path.

        Args:
            template: Text that may contain ``{{#node_id.name#}}`` references.

        Returns:
            A SegmentGroup holding the segments in template order.
        """
        parts = VARIABLE_PATTERN.split(template)
        segments: list[Segment] = []
        for index, part in enumerate(parts):
            if not part:
                continue
            # `re.split` with one capture group alternates literal text (even
            # indices) with captured selector paths (odd indices). Only the
            # captured paths may be resolved against the pool.
            is_reference = index % 2 == 1
            if is_reference and (variable := self.get(part.split("."))):
                segments.append(variable)
            else:
                segments.append(build_segment(part))
        return SegmentGroup(value=segments)

    def get_file(self, selector: Sequence[str], /) -> FileSegment | None:
        """Return the segment for `selector` only if it is a FileSegment, else None."""
        segment = self.get(selector)
        if isinstance(segment, FileSegment):
            return segment
        return None

    def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]:
        """Return a copy of all variables stored under the given node prefix."""

        nodes = self.variable_dictionary.get(prefix)
        if not nodes:
            return {}

        # Deep-copy so callers cannot mutate values held by the pool.
        return {key: deepcopy(variable.value) for key, variable in nodes.items()}

    def flatten(self, *, unprefixed_node_id: str | None = None) -> Mapping[str, object]:
        """Return a selector-style snapshot of the entire variable pool.

        Args:
            unprefixed_node_id: Variables of this node are keyed by bare name;
                all others are keyed as ``"{node_id}.{name}"``.

        Returns:
            A new dict mapping output names to deep copies of the stored values.
        """

        result: dict[str, object] = {}
        for node_id, variables in self.variable_dictionary.items():
            for name, variable in variables.items():
                output_name = name if node_id == unprefixed_node_id else f"{node_id}.{name}"
                result[output_name] = deepcopy(variable.value)

        return result

    @classmethod
    def empty(cls) -> VariablePool:
        """Create an empty variable pool."""
        return cls()
|
||||
Reference in New Issue
Block a user