Files
dify/api/core/workflow/system_variables.py

237 lines
7.9 KiB
Python

from __future__ import annotations
from collections import defaultdict
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any, Protocol, cast
from uuid import uuid4
from graphon.enums import BuiltinNodeTypes
from graphon.variables import build_segment, segment_to_variable
from graphon.variables.segments import Segment
from graphon.variables.variables import RAGPipelineVariableInput, Variable
from .variable_prefixes import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
RAG_PIPELINE_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
)
class SystemVariableKey(StrEnum):
    """Names of the built-in system variables.

    Each value is the variable name used as the second component of a system
    variable selector (see ``system_variable_selector``). Because this is a
    ``StrEnum``, members compare and hash equal to their plain string values,
    so they can be used interchangeably with those strings as dict keys.
    """

    QUERY = "query"
    FILES = "files"
    CONVERSATION_ID = "conversation_id"
    USER_ID = "user_id"
    DIALOGUE_COUNT = "dialogue_count"
    APP_ID = "app_id"
    WORKFLOW_ID = "workflow_id"
    # The member name says "execution" but the stored variable name is "workflow_run_id".
    # ``_normalize_system_variable_values`` additionally accepts "workflow_execution_id"
    # as an input alias for this key.
    WORKFLOW_EXECUTION_ID = "workflow_run_id"
    TIMESTAMP = "timestamp"
    DOCUMENT_ID = "document_id"
    ORIGINAL_DOCUMENT_ID = "original_document_id"
    BATCH = "batch"
    DATASET_ID = "dataset_id"
    DATASOURCE_TYPE = "datasource_type"
    DATASOURCE_INFO = "datasource_info"
    INVOKE_FROM = "invoke_from"
class _VariablePoolReader(Protocol):
    """Structural type for the read side of a variable pool used by this module.

    Any object providing these two methods satisfies the protocol; no
    inheritance is required.
    """

    # Look up a single value by selector; the helpers in this module treat a
    # ``None`` result as "not present".
    def get(self, selector: Sequence[str], /) -> Segment | None: ...
    # Return every entry stored under ``prefix`` (used with the system-variable node id).
    def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]: ...
class _VariablePoolWriter(_VariablePoolReader, Protocol):
    """Variable pool that can also be written to.

    Required by ``preload_node_creation_variables``, which stores loaded
    variables back into the pool.
    """

    # Store ``value`` under ``selector``.
    def add(self, selector: Sequence[str], value: object, /) -> None: ...
class _VariableLoader(Protocol):
    """Structural type for something that can fetch variables by selector."""

    # Fetch the variables for ``selectors`` in one batch. Callers in this module read a
    # ``selector`` attribute from each returned object, so results are expected to expose one.
    def load_variables(self, selectors: list[list[str]]) -> Sequence[object]: ...
def system_variable_name(key: str | SystemVariableKey) -> str:
    """Return the plain string name for ``key``.

    Enum members are unwrapped to their ``.value``; any other input is
    returned unchanged.
    """
    if isinstance(key, SystemVariableKey):
        return key.value
    return key
def system_variable_selector(key: str | SystemVariableKey) -> tuple[str, str]:
    """Build the two-part selector ``(SYSTEM_VARIABLE_NODE_ID, name)`` for ``key``."""
    name = system_variable_name(key)
    return (SYSTEM_VARIABLE_NODE_ID, name)
def _normalize_system_variable_values(values: Mapping[str, Any] | None = None, /, **kwargs: Any) -> dict[str, Any]:
    """Merge, de-alias and clean raw system-variable inputs.

    Args:
        values: Optional mapping of variable name (plain string or
            ``SystemVariableKey``) to value.
        **kwargs: Additional name/value pairs; these win over ``values`` on
            key collisions.

    Returns:
        A new dict keyed by plain string names. It has these properties:

        * The legacy alias ``workflow_execution_id`` is folded into the
          canonical ``workflow_run_id`` key, unless a non-None canonical
          value was already supplied.
        * Entries whose value is ``None`` are dropped.
        * ``files`` defaults to an empty list when absent.
    """
    raw_values = dict(values or {})
    raw_values.update(kwargs)

    run_id_key = SystemVariableKey.WORKFLOW_EXECUTION_ID.value
    workflow_execution_id = raw_values.pop("workflow_execution_id", None)
    # Use the alias when the canonical key is absent *or* explicitly None. Checking only
    # for key presence would let an explicit ``workflow_run_id=None`` discard the alias,
    # and the None would then be filtered out below, losing the run id entirely.
    if workflow_execution_id is not None and raw_values.get(run_id_key) is None:
        raw_values[run_id_key] = workflow_execution_id

    normalized: dict[str, Any] = {}
    for key, value in raw_values.items():
        if value is None:
            continue
        normalized[system_variable_name(key)] = value
    # A fresh list per call, so callers never share a mutable default.
    normalized.setdefault(SystemVariableKey.FILES.value, [])
    return normalized
def build_system_variables(values: Mapping[str, Any] | None = None, /, **kwargs: Any) -> list[Variable]:
    """Convert raw system-variable inputs into ``Variable`` objects.

    Inputs are first normalized by ``_normalize_system_variable_values``.
    Each resulting ``name -> value`` pair then becomes a variable whose
    segment is built from the value, whose selector is the system selector
    for ``name``, and whose name is ``name``. Output order follows the
    normalized dict.
    """
    result: list[Variable] = []
    for name, raw_value in _normalize_system_variable_values(values, **kwargs).items():
        segment = build_segment(raw_value)
        converted = segment_to_variable(
            segment=segment,
            selector=system_variable_selector(name),
            name=name,
        )
        result.append(cast(Variable, converted))
    return result
def default_system_variables() -> list[Variable]:
    """Return system variables seeded only with a freshly generated ``workflow_run_id``.

    ``files`` is also present, because ``build_system_variables`` always
    defaults it.
    """
    fresh_run_id = str(uuid4())
    return build_system_variables(workflow_run_id=fresh_run_id)
def system_variables_to_mapping(system_variables: Sequence[Variable]) -> dict[str, Any]:
    """Flatten variables into a ``{name: value}`` dict.

    When two variables share a name, the later one wins.
    """
    mapping: dict[str, Any] = {}
    for entry in system_variables:
        mapping[entry.name] = entry.value
    return mapping
def _with_selector(variable: Variable, node_id: str) -> Variable:
selector = [node_id, variable.name]
if list(variable.selector) == selector:
return variable
return variable.model_copy(update={"selector": selector})
def build_bootstrap_variables(
    *,
    system_variables: Sequence[Variable] = (),
    environment_variables: Sequence[Variable] = (),
    conversation_variables: Sequence[Variable] = (),
    rag_pipeline_variables: Sequence[RAGPipelineVariableInput] = (),
) -> list[Variable]:
    """Assemble the initial variable list for a workflow run.

    System, environment and conversation variables are re-scoped (via
    ``_with_selector``) onto their respective node ids, in that order. RAG
    pipeline inputs are then grouped by ``belong_to_node_id``. Each group
    becomes one variable whose value is a ``{variable: value}`` dict and whose
    selector is ``(RAG_PIPELINE_VARIABLE_NODE_ID, node_id)``. Those group
    variables are appended after the scoped ones.
    """
    scoped_groups = (
        (SYSTEM_VARIABLE_NODE_ID, system_variables),
        (ENVIRONMENT_VARIABLE_NODE_ID, environment_variables),
        (CONVERSATION_VARIABLE_NODE_ID, conversation_variables),
    )
    result: list[Variable] = [
        _with_selector(item, scope_id) for scope_id, group in scoped_groups for item in group
    ]

    values_by_node: defaultdict[str, dict[str, Any]] = defaultdict(dict)
    for entry in rag_pipeline_variables:
        spec = entry.variable
        values_by_node[spec.belong_to_node_id][spec.variable] = entry.value

    result.extend(
        cast(
            Variable,
            segment_to_variable(
                segment=build_segment(node_values),
                selector=(RAG_PIPELINE_VARIABLE_NODE_ID, owner_node_id),
                name=owner_node_id,
            ),
        )
        for owner_node_id, node_values in values_by_node.items()
    )
    return result
def get_system_segment(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> Segment | None:
    """Look up the system variable ``key`` in ``variable_pool`` and return its segment (or None)."""
    selector = system_variable_selector(key)
    return variable_pool.get(selector)
def get_system_value(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> Any:
    """Return the raw ``.value`` of system variable ``key``, or None when it is not in the pool."""
    segment = get_system_segment(variable_pool, key)
    if segment is None:
        return None
    return segment.value
def get_system_text(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> str | None:
    """Return the ``text`` attribute of system variable ``key`` if it is a ``str``.

    Returns None when the variable is missing, has no ``text`` attribute, or
    its ``text`` is not a string.
    """
    segment = get_system_segment(variable_pool, key)
    candidate = None if segment is None else getattr(segment, "text", None)
    if isinstance(candidate, str):
        return candidate
    return None
def get_all_system_variables(variable_pool: _VariablePoolReader) -> Mapping[str, object]:
    """Return every pool entry stored under the system-variable node id."""
    system_prefix = SYSTEM_VARIABLE_NODE_ID
    return variable_pool.get_by_prefix(system_prefix)
# Node types that, when configured with ``memory``, need ``sys.conversation_id``
# present in the variable pool before the node is constructed (consumed by
# ``get_node_creation_preload_selectors``).
_MEMORY_BOOTSTRAP_NODE_TYPES = frozenset(
    (
        BuiltinNodeTypes.LLM,
        BuiltinNodeTypes.QUESTION_CLASSIFIER,
        BuiltinNodeTypes.PARAMETER_EXTRACTOR,
    )
)
def get_node_creation_preload_selectors(
    *,
    node_type: str,
    node_data: object,
) -> tuple[tuple[str, str], ...]:
    """List selectors that must already be in the pool before building this node.

    Only memory-capable node types that actually have ``memory`` configured
    need anything: namely the ``sys.conversation_id`` selector. Every other
    node gets an empty tuple.
    """
    needs_memory_bootstrap = (
        node_type in _MEMORY_BOOTSTRAP_NODE_TYPES and getattr(node_data, "memory", None) is not None
    )
    if not needs_memory_bootstrap:
        return ()
    return (system_variable_selector(SystemVariableKey.CONVERSATION_ID),)
def preload_node_creation_variables(
    *,
    variable_loader: _VariableLoader,
    variable_pool: _VariablePoolWriter,
    selectors: Sequence[Sequence[str]],
) -> None:
    """Load constructor-time variables before node or graph creation.

    Requested selectors are processed in four steps:

    1. Validate each selector.
    2. De-duplicate them (lists and tuples with equal contents count as the
       same selector).
    3. Check each against ``variable_pool``.
    4. Pass only those with no value in the pool to ``variable_loader``, in a
       single batch.

    If nothing is missing, the loader is not called at all. Each loaded
    variable is added to the pool under the first two components of its own
    ``selector`` attribute.

    Args:
        variable_loader: Source for variables that are missing from the pool.
        variable_pool: Pool that is read for existing values and written with
            loaded ones.
        selectors: Selectors to make available; each needs at least two
            components.

    Raises:
        ValueError: If a requested selector, or the ``selector`` of a loaded
            variable, has fewer than two components.
    """
    seen_selectors: set[tuple[str, ...]] = set()
    selectors_to_load: list[list[str]] = []
    for selector in selectors:
        normalized_selector = tuple(selector)
        if len(normalized_selector) < 2:
            raise ValueError(f"Invalid preload selector: {selector}")
        if normalized_selector in seen_selectors:
            continue
        seen_selectors.add(normalized_selector)
        if variable_pool.get(normalized_selector) is None:
            selectors_to_load.append(list(normalized_selector))

    # Everything requested is already present (or nothing was requested):
    # skip the loader round-trip instead of issuing an empty batch.
    if not selectors_to_load:
        return

    for variable in variable_loader.load_variables(selectors_to_load):
        raw_selector = getattr(variable, "selector", ())
        loaded_selector = list(raw_selector)
        if len(loaded_selector) < 2:
            raise ValueError(f"Invalid loaded variable selector: {raw_selector}")
        # Only the first two selector components are used as the pool key.
        variable_pool.add(loaded_selector[:2], variable)
def inject_default_system_variable_mappings(
    *,
    node_id: str,
    node_data: object,
    node_type: str,
    variable_mapping: Mapping[str, Sequence[str]],
) -> Mapping[str, Sequence[str]]:
    """Add implicit ``sys`` mappings owned by the workflow layer rather than ``graphon``.

    For an LLM node with ``memory`` configured, this ensures the mapping
    contains ``"<node_id>.#sys.query#"`` pointing at the ``sys.query``
    selector. A new dict is returned only when that entry has to be added;
    otherwise the input mapping object itself is returned.
    """
    if node_type != BuiltinNodeTypes.LLM:
        return variable_mapping
    if getattr(node_data, "memory", None) is None:
        return variable_mapping

    query_mapping_key = f"{node_id}.#sys.query#"
    if query_mapping_key in variable_mapping:
        return variable_mapping

    return {**variable_mapping, query_mapping_key: system_variable_selector(SystemVariableKey.QUERY)}