import json from collections.abc import Generator, Mapping, MutableMapping, Sequence from typing import TYPE_CHECKING, Any, cast from graphon.entities.graph_config import NodeConfigDict from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent, VariableUpdatedEvent from graphon.nodes.base.node import Node from graphon.nodes.variable_assigner.common import helpers as common_helpers from graphon.nodes.variable_assigner.common.exc import VariableOperatorNodeError from graphon.variables import SegmentType, Variable, VariableBase from graphon.variables.consts import SELECTORS_LENGTH from . import helpers from .entities import VariableAssignerNodeData, VariableOperationItem from .enums import InputType, Operation from .exc import ( InputTypeNotSupportedError, InvalidDataError, InvalidInputValueError, OperationNotSupportedError, VariableNotFoundError, ) if TYPE_CHECKING: from graphon.entities import GraphInitParams from graphon.runtime import GraphRuntimeState def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): selector_str = ".".join(item.variable_selector) key = f"{node_id}.#{selector_str}#" mapping[key] = item.variable_selector def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): # Keep this in sync with the logic in _run methods... if item.input_type != InputType.VARIABLE: return selector = item.value if not isinstance(selector, list): raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}") if len(selector) < SELECTORS_LENGTH: raise InvalidDataError(f"selector too short, {node_id=}, {item=}") selector_str = ".".join(selector) key = f"{node_id}.#{selector_str}#" mapping[key] = selector class VariableAssignerNode(Node[VariableAssignerNodeData]): node_type = BuiltinNodeTypes.VARIABLE_ASSIGNER def __init__( self, id: str, config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", ): super().__init__( id=id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: """ Check if this Variable Assigner node blocks the output of specific variables. Returns True if this node updates any of the requested conversation variables. """ # Check each item in this Variable Assigner node for item in self.node_data.items: # Convert the item's variable_selector to tuple for comparison item_selector_tuple = tuple(item.variable_selector) # Check if this item updates any of the requested variables if item_selector_tuple in variable_selectors: return True return False @classmethod def version(cls) -> str: return "2" @classmethod def _extract_variable_selector_to_variable_mapping( cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: VariableAssignerNodeData, ) -> Mapping[str, Sequence[str]]: var_mapping: dict[str, Sequence[str]] = {} for item in node_data.items: _target_mapping_from_item(var_mapping, node_id, item) _source_mapping_from_item(var_mapping, node_id, item) return var_mapping def _run(self) -> Generator[NodeEventBase, None, None]: inputs = self.node_data.model_dump() process_data: dict[str, Any] = {} # NOTE: This node has no outputs updated_variable_selectors: list[Sequence[str]] = [] # Preserve intra-node read-after-write behavior without mutating the shared pool # until the engine processes the emitted VariableUpdatedEvent instances. working_variable_pool = self.graph_runtime_state.variable_pool.model_copy(deep=True) try: for item in self.node_data.items: variable = working_variable_pool.get(item.variable_selector) # ==================== Validation Part # Check if variable exists if not isinstance(variable, VariableBase): raise VariableNotFoundError(variable_selector=item.variable_selector) # Check if operation is supported if not helpers.is_operation_supported(variable_type=variable.value_type, operation=item.operation): raise OperationNotSupportedError(operation=item.operation, variable_type=variable.value_type) # Check if variable input is supported if item.input_type == InputType.VARIABLE and not helpers.is_variable_input_supported( operation=item.operation ): raise InputTypeNotSupportedError(input_type=InputType.VARIABLE, operation=item.operation) # Check if constant input is supported if item.input_type == InputType.CONSTANT and not helpers.is_constant_input_supported( variable_type=variable.value_type, operation=item.operation ): raise InputTypeNotSupportedError(input_type=InputType.CONSTANT, operation=item.operation) # Get value from variable pool input_value = item.value if ( item.input_type == InputType.VARIABLE and item.operation not in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST} and item.value is not None ): value = working_variable_pool.get(item.value) if value is None: raise VariableNotFoundError(variable_selector=item.value) # Skip if value is NoneSegment if value.value_type == SegmentType.NONE: continue input_value = value.value # If set string / bytes / bytearray to object, try convert string to object. if ( item.operation == Operation.SET and variable.value_type == SegmentType.OBJECT and isinstance(input_value, str | bytes | bytearray) ): try: input_value = json.loads(input_value) except json.JSONDecodeError: raise InvalidInputValueError(value=input_value) # Check if input value is valid if not helpers.is_input_value_valid( variable_type=variable.value_type, operation=item.operation, value=input_value ): raise InvalidInputValueError(value=input_value) # ==================== Execution Part updated_value = self._handle_item( variable=variable, operation=item.operation, value=input_value, ) updated_variable = variable.model_copy(update={"value": updated_value}) working_variable_pool.add(updated_variable.selector, updated_variable) updated_variable_selectors.append(updated_variable.selector) except VariableOperatorNodeError as e: yield StreamCompletedEvent( node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=inputs, process_data=process_data, error=str(e), ) ) return # The `updated_variable_selectors` is a list contains list[str] which not hashable, # remove duplicated items while preserving the first update order. updated_variable_selectors = list(dict.fromkeys(map(tuple, updated_variable_selectors))) for selector in updated_variable_selectors: variable = working_variable_pool.get(selector) if not isinstance(variable, VariableBase): raise VariableNotFoundError(variable_selector=selector) process_data[variable.name] = variable.value updated_variables = [ common_helpers.variable_to_processed_data(selector, seg) for selector in updated_variable_selectors if (seg := working_variable_pool.get(selector)) is not None ] process_data = common_helpers.set_updated_variables(process_data, updated_variables) for selector in updated_variable_selectors: variable = working_variable_pool.get(selector) if not isinstance(variable, VariableBase): raise VariableNotFoundError(variable_selector=selector) yield VariableUpdatedEvent(variable=cast(Variable, variable)) yield StreamCompletedEvent( node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, process_data=process_data, outputs={}, ) ) def _handle_item( self, *, variable: VariableBase, operation: Operation, value: Any, ): match operation: case Operation.OVER_WRITE: return value case Operation.CLEAR: return SegmentType.get_zero_value(variable.value_type).to_object() case Operation.APPEND: return variable.value + [value] case Operation.EXTEND: return variable.value + value case Operation.SET: return value case Operation.ADD: return variable.value + value case Operation.SUBTRACT: return variable.value - value case Operation.MULTIPLY: return variable.value * value case Operation.DIVIDE: return variable.value / value case Operation.REMOVE_FIRST: # If array is empty, do nothing if not variable.value: return variable.value return variable.value[1:] case Operation.REMOVE_LAST: # If array is empty, do nothing if not variable.value: return variable.value return variable.value[:-1]