from __future__ import annotations

import re
from collections import defaultdict
from collections.abc import Mapping, Sequence
from copy import deepcopy
from typing import Annotated, Any, Union, cast

from pydantic import BaseModel, Field, model_validator

from graphon.file import File, FileAttribute, file_manager
from graphon.variables import Segment, SegmentGroup, VariableBase, build_segment, segment_to_variable
from graphon.variables.consts import SELECTORS_LENGTH
from graphon.variables.segments import FileSegment, ObjectSegment
from graphon.variables.variables import RAGPipelineVariableInput, Variable

# Plain Python values that `VariablePool.add` documents as convertible to a Segment.
VariableValue = Union[str, int, float, dict[str, object], list[object], File]

# Matches template references of the form `{{#node_id.name[.attr...]#}}`.
# The single capturing group is the dotted selector path without the `{{#` / `#}}` delimiters.
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")


def _default_variable_dictionary() -> defaultdict[str, dict[str, Variable]]:
    """Return the empty two-level mapping used as the default for `variable_dictionary`."""
    return defaultdict(dict)


# NOTE: this class is documented with `#` comments rather than a class docstring on purpose, so that
# nothing about the model itself (including anything derived from `__doc__`) changes.
#
# `VariablePool` stores workflow variables in a two-level mapping keyed by node id and variable name,
# and offers lookup (`get`), mutation (`add` / `remove`), template expansion (`convert_template`) and
# snapshot helpers (`get_by_prefix` / `flatten`).
class VariablePool(BaseModel):
    # Node ids under which the legacy bootstrap inputs are stored.
    _SYSTEM_VARIABLE_NODE_ID = "sys"
    _ENVIRONMENT_VARIABLE_NODE_ID = "env"
    _CONVERSATION_VARIABLE_NODE_ID = "conversation"
    _RAG_PIPELINE_VARIABLE_NODE_ID = "rag"

    # Variable dictionary is a dictionary for looking up variables by their selector.
    # The first element of the selector is the node id, it's the first-level key in the dictionary.
    # The second element of the selector is the variable name, it's used verbatim as the key in the
    # second-level dictionary (see `_selector_to_keys`).
    variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field(
        description="Variables mapping",
        default_factory=_default_variable_dictionary,
    )
    # Legacy constructor inputs. They are excluded from serialization and are reset to empty values
    # once `_load_legacy_bootstrap_inputs` has run.
    system_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True)
    environment_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True)
    conversation_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True)
    rag_pipeline_variables: Sequence[RAGPipelineVariableInput] = Field(default_factory=tuple, exclude=True)
    user_inputs: Mapping[str, Any] = Field(default_factory=dict, exclude=True)

    @model_validator(mode="after")
    def _load_legacy_bootstrap_inputs(self) -> VariablePool:
        """
        Accept legacy constructor kwargs that still appear throughout the workflow layer
        while keeping serialized state focused on `variable_dictionary`.
        """
        self._ingest_legacy_variables(self.system_variables, node_id=self._SYSTEM_VARIABLE_NODE_ID)
        self._ingest_legacy_variables(self.environment_variables, node_id=self._ENVIRONMENT_VARIABLE_NODE_ID)
        self._ingest_legacy_variables(self.conversation_variables, node_id=self._CONVERSATION_VARIABLE_NODE_ID)
        self._ingest_legacy_rag_variables(self.rag_pipeline_variables)

        # These kwargs are accepted for compatibility but should not affect the
        # stable serialized form or model equality.
        self.system_variables = ()
        self.environment_variables = ()
        self.conversation_variables = ()
        self.rag_pipeline_variables = ()
        self.user_inputs = {}
        return self

    def _ingest_legacy_variables(self, variables: Sequence[Variable], *, node_id: str) -> None:
        """
        Store each variable under `[node_id, variable.name]`.

        A variable whose own selector differs from that path is copied with the selector
        rewritten, so the stored selector always matches the storage location.
        """
        for variable in variables:
            selector = [node_id, variable.name]
            normalized_variable = variable
            if list(variable.selector) != selector:
                normalized_variable = variable.model_copy(update={"selector": selector})
            self.add(normalized_variable.selector, normalized_variable)

    def _ingest_legacy_rag_variables(self, rag_pipeline_variables: Sequence[RAGPipelineVariableInput]) -> None:
        """
        Group RAG pipeline inputs by the node they belong to and store one mapping per node.

        Each group is added under `[_RAG_PIPELINE_VARIABLE_NODE_ID, belong_to_node_id]` as a plain
        dict of `{variable name: value}`, which `add` converts to a segment.
        """
        if not rag_pipeline_variables:
            return

        values_by_node_id: defaultdict[str, dict[str, Any]] = defaultdict(dict)
        for rag_variable_input in rag_pipeline_variables:
            values_by_node_id[rag_variable_input.variable.belong_to_node_id][rag_variable_input.variable.variable] = (
                rag_variable_input.value
            )

        for node_id, value in values_by_node_id.items():
            self.add((self._RAG_PIPELINE_VARIABLE_NODE_ID, node_id), value)

    def add(self, selector: Sequence[str], value: Any, /):
        """
        Add a variable to the variable pool.

        This method accepts a selector path and a value, converting the value to a Variable
        object if necessary before storing it in the pool.

        Args:
            selector: A two-element sequence containing [node_id, variable_name].
                The selector must have exactly 2 elements to be valid.
            value: The value to store. Can be a Variable, Segment, or any value that can be
                converted to a Segment (str, int, float, dict, list, File).

        Raises:
            ValueError: If selector length is not exactly 2 elements.

        Note:
            While non-Segment values are currently accepted and automatically converted,
            it's recommended to pass Segment or Variable objects directly.
        """
        if len(selector) != SELECTORS_LENGTH:
            raise ValueError(
                f"Invalid selector: expected {SELECTORS_LENGTH} elements (node_id, variable_name), "
                f"got {len(selector)} elements"
            )

        if isinstance(value, VariableBase):
            # Already a variable: stored as-is (its own selector is not rewritten here).
            variable = value
        elif isinstance(value, Segment):
            variable = segment_to_variable(segment=value, selector=selector)
        else:
            segment = build_segment(value)
            variable = segment_to_variable(segment=segment, selector=selector)

        node_id, name = self._selector_to_keys(selector)
        # Based on the definition of `Variable`,
        # `VariableBase` instances can be safely used as `Variable` since they are compatible.
        self.variable_dictionary[node_id][name] = cast(Variable, variable)

    @classmethod
    def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:
        """Split a selector into its (node_id, variable_name) dictionary keys."""
        return selector[0], selector[1]

    def _has(self, selector: Sequence[str]) -> bool:
        """Return True if a variable is stored at `selector`, without creating any entries."""
        node_id, name = self._selector_to_keys(selector)
        # `.get` (unlike indexing) does not trigger the defaultdict factory.
        return name in self.variable_dictionary.get(node_id, {})

    def get(self, selector: Sequence[str], /) -> Segment | None:
        """
        Retrieve a variable's value from the pool as a Segment.

        This method supports both simple selectors [node_id, variable_name] and extended
        selectors that include attribute access for FileSegment and ObjectSegment types.

        Args:
            selector: A sequence with at least 2 elements:
                - [node_id, variable_name]: Returns the full segment
                - [node_id, variable_name, attr, ...]: Returns a nested value from
                  FileSegment (e.g., 'url', 'name') or ObjectSegment

        Returns:
            The Segment associated with the selector, or None if not found.
            None is also returned when the selector has fewer than 2 elements, when the
            requested file attribute is not a valid `FileAttribute` value, or when a nested
            key is missing or the value being traversed is not a dict.

        Note:
            For a FileSegment only the third selector element is consulted; any further
            elements are ignored. Errors raised by `file_manager.get_attr` itself (if any)
            are not handled here.
        """
        if len(selector) < SELECTORS_LENGTH:
            return None

        node_id, name = self._selector_to_keys(selector)
        # `.get` avoids creating an empty node entry in the defaultdict on a miss.
        node_map = self.variable_dictionary.get(node_id)
        if node_map is None:
            return None

        segment: Segment | None = node_map.get(name)
        if segment is None:
            return None

        if len(selector) == SELECTORS_LENGTH:
            return segment

        if isinstance(segment, FileSegment):
            attr = selector[2]
            # Python support `attr in FileAttribute` after 3.12
            if attr not in {item.value for item in FileAttribute}:
                return None
            attr = FileAttribute(attr)
            attr_value = file_manager.get_attr(file=segment.value, attr=attr)
            return build_segment(attr_value)

        # Navigate through nested attributes
        result: Any = segment
        for attr in selector[2:]:
            result = self._extract_value(result)
            result = self._get_nested_attribute(result, attr)
            if result is None:
                return None

        # Return result as Segment
        return result if isinstance(result, Segment) else build_segment(result)

    def _extract_value(self, obj: Any):
        """Extract the actual value from an ObjectSegment; any other object is returned unchanged."""
        return obj.value if isinstance(obj, ObjectSegment) else obj

    def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Segment | None:
        """
        Get a nested attribute from a dictionary-like object.

        Args:
            obj: The dictionary-like object to search.
            attr: The key to look up.

        Returns:
            Segment | None:
                The corresponding Segment built from the attribute value if the key exists,
                otherwise None. None is also returned when `obj` is not a `dict`.
        """
        if not isinstance(obj, dict) or attr not in obj:
            return None
        return build_segment(obj.get(attr))

    def remove(self, selector: Sequence[str], /):
        """
        Remove variables from the variable pool based on the given selector.

        Args:
            selector (Sequence[str]): A sequence of strings representing the selector.
                - empty: nothing happens
                - [node_id]: every variable of that node is dropped (the node entry is
                  reset to an empty mapping)
                - [node_id, variable_name, ...]: that single variable is dropped; elements
                  beyond the second are ignored

        Returns:
            None
        """
        if not selector:
            return
        if len(selector) == 1:
            self.variable_dictionary[selector[0]] = {}
            return

        node_id, name = self._selector_to_keys(selector)
        # Look the node up with `.get`: indexing the defaultdict would insert an empty entry for an
        # unknown node id, so a no-op removal would still alter the serialized pool.
        node_map = self.variable_dictionary.get(node_id)
        if node_map is not None:
            node_map.pop(name, None)

    def convert_template(self, template: str, /):
        """
        Expand `{{#node_id.name#}}` references in `template` into a SegmentGroup.

        The template is split on `VARIABLE_PATTERN`; every non-empty piece that contains a
        dot and resolves via `get` is replaced by the resolved segment, and every other
        piece is kept as a literal segment.

        NOTE(review): pieces are not distinguished by origin, so literal text that happens to
        look like a resolvable dotted selector is also substituted - confirm this is intended.
        """
        parts = VARIABLE_PATTERN.split(template)
        segments: list[Segment] = []
        # `filter(None, ...)` drops the empty strings produced by the split.
        for part in filter(None, parts):
            if "." in part and (variable := self.get(part.split("."))):
                segments.append(variable)
            else:
                segments.append(build_segment(part))
        return SegmentGroup(value=segments)

    def get_file(self, selector: Sequence[str], /) -> FileSegment | None:
        """Return the segment at `selector` if it is a FileSegment, otherwise None."""
        segment = self.get(selector)
        if isinstance(segment, FileSegment):
            return segment
        return None

    def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]:
        """Return a copy of all variables stored under the given node prefix."""
        nodes = self.variable_dictionary.get(prefix)
        if not nodes:
            return {}

        # Values are deep-copied so callers cannot mutate the pool through the result.
        return {key: deepcopy(variable.value) for key, variable in nodes.items()}

    def flatten(self, *, unprefixed_node_id: str | None = None) -> Mapping[str, object]:
        """
        Return a selector-style snapshot of the entire variable pool.

        Keys are `"{node_id}.{name}"`, except for variables of `unprefixed_node_id`, whose
        keys are the bare variable name. Values are deep copies.
        """
        result: dict[str, object] = {}
        for node_id, variables in self.variable_dictionary.items():
            for name, variable in variables.items():
                output_name = name if node_id == unprefixed_node_id else f"{node_id}.{name}"
                result[output_name] = deepcopy(variable.value)

        return result

    @classmethod
    def empty(cls) -> VariablePool:
        """Create an empty variable pool."""
        return cls()