mirror of
https://github.com/langgenius/dify.git
synced 2026-04-19 18:27:27 +08:00
169 lines
5.2 KiB
Python
169 lines
5.2 KiB
Python
from __future__ import annotations
|
|
|
|
from enum import StrEnum
|
|
from typing import Annotated, Any, Literal, TypeAlias, cast
|
|
|
|
from pydantic import AfterValidator, BaseModel, Field, TypeAdapter, field_validator
|
|
from pydantic_core.core_schema import ValidationInfo
|
|
|
|
from dify_graph.entities.base_node_data import BaseNodeData
|
|
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
|
from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState
|
|
from dify_graph.utils.condition.entities import Condition
|
|
from dify_graph.variables.types import SegmentType
|
|
|
|
LoopValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
|
|
LoopValueMapping: TypeAlias = dict[str, LoopValue]
|
|
VariableSelector: TypeAlias = list[str]
|
|
|
|
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
|
|
|
|
_VALID_VAR_TYPE = frozenset(
|
|
[
|
|
SegmentType.STRING,
|
|
SegmentType.NUMBER,
|
|
SegmentType.OBJECT,
|
|
SegmentType.BOOLEAN,
|
|
SegmentType.ARRAY_STRING,
|
|
SegmentType.ARRAY_NUMBER,
|
|
SegmentType.ARRAY_OBJECT,
|
|
SegmentType.ARRAY_BOOLEAN,
|
|
]
|
|
)
|
|
|
|
|
|
def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
|
|
if seg_type not in _VALID_VAR_TYPE:
|
|
raise ValueError(...)
|
|
return seg_type
|
|
|
|
|
|
def _validate_loop_value(value: object) -> LoopValue:
|
|
if value is None or isinstance(value, (str, int, float, bool)):
|
|
return cast(LoopValue, value)
|
|
|
|
if isinstance(value, list):
|
|
return [_validate_loop_value(item) for item in value]
|
|
|
|
if isinstance(value, dict):
|
|
normalized: dict[str, LoopValue] = {}
|
|
for key, item in value.items():
|
|
if not isinstance(key, str):
|
|
raise TypeError("Loop values only support string object keys")
|
|
normalized[key] = _validate_loop_value(item)
|
|
return normalized
|
|
|
|
raise TypeError("Loop values must be JSON-like primitives, arrays, or objects")
|
|
|
|
|
|
def _validate_loop_value_mapping(value: object) -> LoopValueMapping:
|
|
if not isinstance(value, dict):
|
|
raise TypeError("Loop outputs must be an object")
|
|
|
|
normalized: LoopValueMapping = {}
|
|
for key, item in value.items():
|
|
if not isinstance(key, str):
|
|
raise TypeError("Loop output keys must be strings")
|
|
normalized[key] = _validate_loop_value(item)
|
|
return normalized
|
|
|
|
|
|
class LoopVariableData(BaseModel):
|
|
"""
|
|
Loop Variable Data.
|
|
"""
|
|
|
|
label: str
|
|
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
|
|
value_type: Literal["variable", "constant"]
|
|
value: LoopValue | VariableSelector | None = None
|
|
|
|
@field_validator("value", mode="before")
|
|
@classmethod
|
|
def validate_value(cls, value: object, validation_info: ValidationInfo) -> LoopValue | VariableSelector | None:
|
|
value_type = validation_info.data.get("value_type")
|
|
if value_type == "variable":
|
|
if value is None:
|
|
raise ValueError("Variable loop inputs require a selector")
|
|
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
|
|
if value_type == "constant":
|
|
return _validate_loop_value(value)
|
|
raise ValueError(f"Unknown loop variable value type: {value_type}")
|
|
|
|
def require_variable_selector(self) -> VariableSelector:
|
|
if self.value_type != "variable":
|
|
raise ValueError(f"Expected variable loop input, got {self.value_type}")
|
|
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
|
|
|
|
def require_constant_value(self) -> LoopValue:
|
|
if self.value_type != "constant":
|
|
raise ValueError(f"Expected constant loop input, got {self.value_type}")
|
|
return _validate_loop_value(self.value)
|
|
|
|
|
|
class LoopNodeData(BaseLoopNodeData):
|
|
type: NodeType = BuiltinNodeTypes.LOOP
|
|
loop_count: int # Maximum number of loops
|
|
break_conditions: list[Condition] # Conditions to break the loop
|
|
logical_operator: Literal["and", "or"]
|
|
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
|
|
outputs: LoopValueMapping = Field(default_factory=dict)
|
|
|
|
@field_validator("outputs", mode="before")
|
|
@classmethod
|
|
def validate_outputs(cls, value: object) -> LoopValueMapping:
|
|
if value is None:
|
|
return {}
|
|
return _validate_loop_value_mapping(value)
|
|
|
|
|
|
class LoopStartNodeData(BaseNodeData):
|
|
"""
|
|
Loop Start Node Data.
|
|
"""
|
|
|
|
type: NodeType = BuiltinNodeTypes.LOOP_START
|
|
|
|
|
|
class LoopEndNodeData(BaseNodeData):
|
|
"""
|
|
Loop End Node Data.
|
|
"""
|
|
|
|
type: NodeType = BuiltinNodeTypes.LOOP_END
|
|
|
|
|
|
class LoopState(BaseLoopState):
|
|
"""
|
|
Loop State.
|
|
"""
|
|
|
|
outputs: list[LoopValue] = Field(default_factory=list)
|
|
current_output: LoopValue | None = None
|
|
|
|
class MetaData(BaseLoopState.MetaData):
|
|
"""
|
|
Data.
|
|
"""
|
|
|
|
loop_length: int
|
|
|
|
def get_last_output(self) -> LoopValue | None:
|
|
"""
|
|
Get last output.
|
|
"""
|
|
if self.outputs:
|
|
return self.outputs[-1]
|
|
return None
|
|
|
|
def get_current_output(self) -> LoopValue | None:
|
|
"""
|
|
Get current output.
|
|
"""
|
|
return self.current_output
|
|
|
|
|
|
class LoopCompletedReason(StrEnum):
|
|
LOOP_BREAK = "loop_break"
|
|
LOOP_COMPLETED = "loop_completed"
|