# Source: dify/api/dify_graph/nodes/loop/entities.py
# Snapshot: 2026-03-18 19:16:48 +08:00 (169 lines, 5.2 KiB, Python)

from __future__ import annotations
from enum import StrEnum
from typing import Annotated, Any, Literal, TypeAlias, cast
from pydantic import AfterValidator, BaseModel, Field, TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState
from dify_graph.utils.condition.entities import Condition
from dify_graph.variables.types import SegmentType
# JSON-like value that a loop variable or loop output may hold.
LoopValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
# String-keyed mapping of loop values (used for the loop node's outputs).
LoopValueMapping: TypeAlias = dict[str, LoopValue]
# A variable selector is a list of strings.
VariableSelector: TypeAlias = list[str]

# Built once at import time and shared by every selector validation in this module.
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)

# Segment types accepted by `_is_valid_var_type`.
_VALID_VAR_TYPE = frozenset(
    {
        SegmentType.STRING,
        SegmentType.NUMBER,
        SegmentType.OBJECT,
        SegmentType.BOOLEAN,
        SegmentType.ARRAY_STRING,
        SegmentType.ARRAY_NUMBER,
        SegmentType.ARRAY_OBJECT,
        SegmentType.ARRAY_BOOLEAN,
    }
)
def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
    """Return ``seg_type`` unchanged when it is allowed for a loop variable.

    Args:
        seg_type: The segment type declared for a loop variable.

    Returns:
        The same ``seg_type`` that was passed in.

    Raises:
        ValueError: If ``seg_type`` is not a member of ``_VALID_VAR_TYPE``.
    """
    if seg_type not in _VALID_VAR_TYPE:
        # Sorted so the message is stable regardless of frozenset iteration order.
        supported = ", ".join(sorted(str(item) for item in _VALID_VAR_TYPE))
        raise ValueError(f"Unsupported loop variable type: {seg_type}. Supported types: {supported}")
    return seg_type
def _validate_loop_value(value: object) -> LoopValue:
if value is None or isinstance(value, (str, int, float, bool)):
return cast(LoopValue, value)
if isinstance(value, list):
return [_validate_loop_value(item) for item in value]
if isinstance(value, dict):
normalized: dict[str, LoopValue] = {}
for key, item in value.items():
if not isinstance(key, str):
raise TypeError("Loop values only support string object keys")
normalized[key] = _validate_loop_value(item)
return normalized
raise TypeError("Loop values must be JSON-like primitives, arrays, or objects")
def _validate_loop_value_mapping(value: object) -> LoopValueMapping:
if not isinstance(value, dict):
raise TypeError("Loop outputs must be an object")
normalized: LoopValueMapping = {}
for key, item in value.items():
if not isinstance(key, str):
raise TypeError("Loop output keys must be strings")
normalized[key] = _validate_loop_value(item)
return normalized
class LoopVariableData(BaseModel):
    """
    Declaration of a single loop variable.

    ``value`` is interpreted according to ``value_type``: a variable selector
    (list of strings) for ``"variable"``, or a JSON-like literal for ``"constant"``.
    """

    label: str
    # Restricted to the segment types accepted by ``_is_valid_var_type``.
    var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
    # Declared before ``value`` because the ``value`` validator looks this field up
    # in ``validation_info.data``.
    value_type: Literal["variable", "constant"]
    value: LoopValue | VariableSelector | None = None

    @field_validator("value", mode="before")
    @classmethod
    def validate_value(cls, value: object, validation_info: ValidationInfo) -> LoopValue | VariableSelector | None:
        """
        Validate the raw ``value`` according to the declared ``value_type``.

        Raises ``ValueError`` when ``value_type`` is ``"variable"`` and ``value`` is
        ``None``, or when ``value_type`` is neither ``"variable"`` nor ``"constant"``
        (including when it is missing from the validated data). Errors raised by
        ``_validate_loop_value`` or the selector adapter propagate unchanged.
        """
        kind = validation_info.data.get("value_type")
        if kind == "constant":
            return _validate_loop_value(value)
        if kind != "variable":
            raise ValueError(f"Unknown loop variable value type: {kind}")
        if value is None:
            raise ValueError("Variable loop inputs require a selector")
        return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)

    def require_variable_selector(self) -> VariableSelector:
        """
        Return ``value`` validated as a variable selector.

        Raises ``ValueError`` if this loop variable is not of type ``"variable"``.
        """
        if self.value_type == "variable":
            return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
        raise ValueError(f"Expected variable loop input, got {self.value_type}")

    def require_constant_value(self) -> LoopValue:
        """
        Return ``value`` validated as a JSON-like constant.

        Raises ``ValueError`` if this loop variable is not of type ``"constant"``.
        """
        if self.value_type == "constant":
            return _validate_loop_value(self.value)
        raise ValueError(f"Expected constant loop input, got {self.value_type}")
class LoopNodeData(BaseLoopNodeData):
    """
    Loop Node Data.
    """

    type: NodeType = BuiltinNodeTypes.LOOP
    # Maximum number of loops
    loop_count: int
    # Conditions to break the loop
    break_conditions: list[Condition]
    # NOTE(review): presumably how ``break_conditions`` are combined; confirm against the loop node.
    logical_operator: Literal["and", "or"]
    loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
    outputs: LoopValueMapping = Field(default_factory=dict)

    @field_validator("outputs", mode="before")
    @classmethod
    def validate_outputs(cls, value: object) -> LoopValueMapping:
        """
        Normalize raw ``outputs``: ``None`` becomes an empty dict, anything else
        is checked by ``_validate_loop_value_mapping``.
        """
        return {} if value is None else _validate_loop_value_mapping(value)
class LoopStartNodeData(BaseNodeData):
    """
    Node data for the loop start node.

    Only sets the default ``type`` to ``BuiltinNodeTypes.LOOP_START``.
    """

    type: NodeType = BuiltinNodeTypes.LOOP_START
class LoopEndNodeData(BaseNodeData):
    """
    Node data for the loop end node.

    Only sets the default ``type`` to ``BuiltinNodeTypes.LOOP_END``.
    """

    type: NodeType = BuiltinNodeTypes.LOOP_END
class LoopState(BaseLoopState):
    """
    State of a loop: the list of recorded outputs plus the current output.
    """

    outputs: list[LoopValue] = Field(default_factory=list)
    current_output: LoopValue | None = None

    class MetaData(BaseLoopState.MetaData):
        """
        Loop metadata, extending the base loop state's metadata with ``loop_length``.
        """

        loop_length: int

    def get_last_output(self) -> LoopValue | None:
        """
        Return the final entry of ``outputs``, or ``None`` when ``outputs`` is empty.
        """
        return self.outputs[-1] if self.outputs else None

    def get_current_output(self) -> LoopValue | None:
        """
        Return ``current_output`` as stored on the state.
        """
        return self.current_output
class LoopCompletedReason(StrEnum):
LOOP_BREAK = "loop_break"
LOOP_COMPLETED = "loop_completed"