mirror of
https://github.com/langgenius/dify.git
synced 2026-03-04 23:36:20 +08:00
105 lines
2.3 KiB
Python
105 lines
2.3 KiB
Python
from enum import StrEnum
|
|
from typing import Annotated, Any, Literal
|
|
|
|
from pydantic import AfterValidator, BaseModel, Field, field_validator
|
|
|
|
from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
|
|
from dify_graph.utils.condition.entities import Condition
|
|
from dify_graph.variables.types import SegmentType
|
|
|
|
_VALID_VAR_TYPE = frozenset(
|
|
[
|
|
SegmentType.STRING,
|
|
SegmentType.NUMBER,
|
|
SegmentType.OBJECT,
|
|
SegmentType.BOOLEAN,
|
|
SegmentType.ARRAY_STRING,
|
|
SegmentType.ARRAY_NUMBER,
|
|
SegmentType.ARRAY_OBJECT,
|
|
SegmentType.ARRAY_BOOLEAN,
|
|
]
|
|
)
|
|
|
|
|
|
def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
|
|
if seg_type not in _VALID_VAR_TYPE:
|
|
raise ValueError(...)
|
|
return seg_type
|
|
|
|
|
|
class LoopVariableData(BaseModel):
|
|
"""
|
|
Loop Variable Data.
|
|
"""
|
|
|
|
label: str
|
|
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
|
|
value_type: Literal["variable", "constant"]
|
|
value: Any | list[str] | None = None
|
|
|
|
|
|
class LoopNodeData(BaseLoopNodeData):
|
|
loop_count: int # Maximum number of loops
|
|
break_conditions: list[Condition] # Conditions to break the loop
|
|
logical_operator: Literal["and", "or"]
|
|
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
|
|
outputs: dict[str, Any] = Field(default_factory=dict)
|
|
|
|
@field_validator("outputs", mode="before")
|
|
@classmethod
|
|
def validate_outputs(cls, v):
|
|
if v is None:
|
|
return {}
|
|
return v
|
|
|
|
|
|
class LoopStartNodeData(BaseNodeData):
|
|
"""
|
|
Loop Start Node Data.
|
|
"""
|
|
|
|
pass
|
|
|
|
|
|
class LoopEndNodeData(BaseNodeData):
|
|
"""
|
|
Loop End Node Data.
|
|
"""
|
|
|
|
pass
|
|
|
|
|
|
class LoopState(BaseLoopState):
|
|
"""
|
|
Loop State.
|
|
"""
|
|
|
|
outputs: list[Any] = Field(default_factory=list)
|
|
current_output: Any = None
|
|
|
|
class MetaData(BaseLoopState.MetaData):
|
|
"""
|
|
Data.
|
|
"""
|
|
|
|
loop_length: int
|
|
|
|
def get_last_output(self) -> Any:
|
|
"""
|
|
Get last output.
|
|
"""
|
|
if self.outputs:
|
|
return self.outputs[-1]
|
|
return None
|
|
|
|
def get_current_output(self) -> Any:
|
|
"""
|
|
Get current output.
|
|
"""
|
|
return self.current_output
|
|
|
|
|
|
class LoopCompletedReason(StrEnum):
|
|
LOOP_BREAK = "loop_break"
|
|
LOOP_COMPLETED = "loop_completed"
|