mirror of
https://github.com/langgenius/dify.git
synced 2026-05-01 07:58:02 +08:00
refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025)
refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025) This PR addresses serialization issues in the VariablePool model by separating the `value_type` tags for `IntegerSegment`/`FloatSegment` and `IntegerVariable`/`FloatVariable`. Previously, both Integer and Float types shared the same `SegmentType.NUMBER` tag, causing conflicts during serialization. Key changes: - Introduce distinct `value_type` tags for Integer and Float segments/variables - Add `VariableUnion` and `SegmentUnion` types for proper type discrimination - Leverage Pydantic's discriminated union feature for seamless serialization/deserialization - Enable accurate serialization of data structures containing these types Closes #22024.
This commit is contained in:
@ -1,9 +1,9 @@
|
||||
import json
|
||||
import sys
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
from typing import Annotated, Any, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
|
||||
|
||||
from core.file import File
|
||||
|
||||
@ -11,6 +11,11 @@ from .types import SegmentType
|
||||
|
||||
|
||||
class Segment(BaseModel):
|
||||
"""Segment is runtime type used during the execution of workflow.
|
||||
|
||||
Note: this class is abstract, you should use subclasses of this class instead.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
value_type: SegmentType
|
||||
@ -73,7 +78,7 @@ class StringSegment(Segment):
|
||||
|
||||
|
||||
class FloatSegment(Segment):
|
||||
value_type: SegmentType = SegmentType.NUMBER
|
||||
value_type: SegmentType = SegmentType.FLOAT
|
||||
value: float
|
||||
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
|
||||
# The following tests cannot pass.
|
||||
@ -92,7 +97,7 @@ class FloatSegment(Segment):
|
||||
|
||||
|
||||
class IntegerSegment(Segment):
    # Integer segments carry their own INTEGER tag instead of the NUMBER tag
    # previously shared with FloatSegment; the stale NUMBER default that was
    # immediately overridden has been dropped.
    value_type: SegmentType = SegmentType.INTEGER
    # The wrapped integer value.
    value: int
|
||||
|
||||
|
||||
@ -181,3 +186,46 @@ class ArrayFileSegment(ArraySegment):
|
||||
@property
def text(self) -> str:
    """Return the text form of this segment, which is always the empty string."""
    return ""
|
||||
|
||||
|
||||
def get_segment_discriminator(v: Any) -> SegmentType | None:
    """Extract the discriminator tag used to pick a variant of a tagged union.

    Args:
        v: Either a `Segment` instance or a plain `dict`.

    Returns:
        The `value_type` of a `Segment` instance, or the `SegmentType` parsed
        from the dict's "value_type" entry. `None` is returned when `v` is
        neither a `Segment` nor a dict, when the entry is missing or `None`,
        or when its value is not a valid `SegmentType`.
    """
    # Already-constructed segments carry their tag directly.
    if isinstance(v, Segment):
        return v.value_type

    # Anything that is not a raw mapping has no discriminator to read.
    if not isinstance(v, dict):
        return None

    raw_tag = v.get("value_type")
    if raw_tag is None:
        return None

    try:
        return SegmentType(raw_tag)
    except ValueError:
        # Unknown tag value: report "no discriminator found".
        return None
|
||||
|
||||
|
||||
# `SegmentUnion` is the type to use when a segment has to be serialized or
# deserialized by Pydantic. Use plain `Segment` for type hints where
# serialization is not required.
#
# Note:
# - All variants in `SegmentUnion` must inherit from the `Segment` class.
# - The union must include all non-abstract subclasses of `Segment`, except:
#   - `SegmentGroup`, which is not added to the variable pool.
#   - `Variable` and its subclasses, which are handled by `VariableUnion`.
SegmentUnion: TypeAlias = Annotated[
    (
        Annotated[NoneSegment, Tag(SegmentType.NONE)]
        | Annotated[StringSegment, Tag(SegmentType.STRING)]
        | Annotated[FloatSegment, Tag(SegmentType.FLOAT)]
        | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
        | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
        | Annotated[FileSegment, Tag(SegmentType.FILE)]
        | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
        | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
        | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
        | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
        | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
    ),
    # The tag of each variant is matched against the value returned here.
    Discriminator(get_segment_discriminator),
]
|
||||
|
||||
@ -1,8 +1,27 @@
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.file.models import File
|
||||
|
||||
|
||||
class ArrayValidation(StrEnum):
|
||||
"""Strategy for validating array elements"""
|
||||
|
||||
# Skip element validation (only check array container)
|
||||
NONE = "none"
|
||||
|
||||
# Validate the first element (if array is non-empty)
|
||||
FIRST = "first"
|
||||
|
||||
# Validate all elements in the array.
|
||||
ALL = "all"
|
||||
|
||||
|
||||
class SegmentType(StrEnum):
|
||||
NUMBER = "number"
|
||||
INTEGER = "integer"
|
||||
FLOAT = "float"
|
||||
STRING = "string"
|
||||
OBJECT = "object"
|
||||
SECRET = "secret"
|
||||
@ -19,16 +38,141 @@ class SegmentType(StrEnum):
|
||||
|
||||
GROUP = "group"
|
||||
|
||||
def is_array_type(self) -> bool:
    """Return True when this type is a member of the module-level `_ARRAY_TYPES` set."""
    return self in _ARRAY_TYPES
|
||||
|
||||
@classmethod
|
||||
def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]:
|
||||
"""
|
||||
Attempt to infer the `SegmentType` based on the Python type of the `value` parameter.
|
||||
|
||||
Returns `None` if no appropriate `SegmentType` can be determined for the given `value`.
|
||||
For example, this may occur if the input is a generic Python object of type `object`.
|
||||
"""
|
||||
|
||||
if isinstance(value, list):
|
||||
elem_types: set[SegmentType] = set()
|
||||
for i in value:
|
||||
segment_type = cls.infer_segment_type(i)
|
||||
if segment_type is None:
|
||||
return None
|
||||
|
||||
elem_types.add(segment_type)
|
||||
|
||||
if len(elem_types) != 1:
|
||||
if elem_types.issubset(_NUMERICAL_TYPES):
|
||||
return SegmentType.ARRAY_NUMBER
|
||||
return SegmentType.ARRAY_ANY
|
||||
elif all(i.is_array_type() for i in elem_types):
|
||||
return SegmentType.ARRAY_ANY
|
||||
match elem_types.pop():
|
||||
case SegmentType.STRING:
|
||||
return SegmentType.ARRAY_STRING
|
||||
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
|
||||
return SegmentType.ARRAY_NUMBER
|
||||
case SegmentType.OBJECT:
|
||||
return SegmentType.ARRAY_OBJECT
|
||||
case SegmentType.FILE:
|
||||
return SegmentType.ARRAY_FILE
|
||||
case SegmentType.NONE:
|
||||
return SegmentType.ARRAY_ANY
|
||||
case _:
|
||||
# This should be unreachable.
|
||||
raise ValueError(f"not supported value {value}")
|
||||
if value is None:
|
||||
return SegmentType.NONE
|
||||
elif isinstance(value, int) and not isinstance(value, bool):
|
||||
return SegmentType.INTEGER
|
||||
elif isinstance(value, float):
|
||||
return SegmentType.FLOAT
|
||||
elif isinstance(value, str):
|
||||
return SegmentType.STRING
|
||||
elif isinstance(value, dict):
|
||||
return SegmentType.OBJECT
|
||||
elif isinstance(value, File):
|
||||
return SegmentType.FILE
|
||||
elif isinstance(value, str):
|
||||
return SegmentType.STRING
|
||||
else:
|
||||
return None
|
||||
|
||||
def _validate_array(self, value: Any, array_validation: ArrayValidation) -> bool:
    """Check whether `value` is a list acceptable for this array type.

    Args:
        value: The candidate value.
        array_validation: How many elements to validate against the element
            type registered for this array type in `_ARRAY_ELEMENT_TYPES_MAPPING`.

    Returns:
        False when `value` is not a list. True for an empty list and for
        `ARRAY_ANY`. Otherwise the result of validating no element (`NONE`),
        the first element (`FIRST`) or every element (`ALL`).

    Raises:
        KeyError: if this type has no entry in `_ARRAY_ELEMENT_TYPES_MAPPING`.
    """
    if not isinstance(value, list):
        return False
    # Skip element validation if array is empty
    if len(value) == 0:
        return True
    # ARRAY_ANY places no constraint on its elements.
    if self == SegmentType.ARRAY_ANY:
        return True
    element_type = _ARRAY_ELEMENT_TYPES_MAPPING[self]

    if array_validation == ArrayValidation.NONE:
        return True
    elif array_validation == ArrayValidation.FIRST:
        return element_type.is_valid(value[0])
    else:
        # Feed the boolean results straight to `all`. Wrapping each result in a
        # one-element list (as before) is always truthy and never rejects anything.
        return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value)
|
||||
|
||||
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool:
    """
    Check if a value matches the segment type.
    Users of `SegmentType` should call this method, instead of using
    `isinstance` manually.

    `INTEGER` and `FLOAT` are validated with the same int-or-float check as
    `NUMBER`. Note that `bool` passes this check because it subclasses `int`.

    Args:
        value: The value to validate
        array_validation: Validation strategy for array types (ignored for non-array types)

    Returns:
        True if the value matches the type under the given validation strategy

    Raises:
        AssertionError: if this type has no validation rule in this method.
    """
    if self.is_array_type():
        return self._validate_array(value, array_validation)
    elif self in (SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT):
        # INTEGER and FLOAT previously fell through to the AssertionError below.
        return isinstance(value, (int, float))
    elif self == SegmentType.STRING:
        return isinstance(value, str)
    elif self == SegmentType.OBJECT:
        return isinstance(value, dict)
    elif self == SegmentType.SECRET:
        return isinstance(value, str)
    elif self == SegmentType.FILE:
        return isinstance(value, File)
    elif self == SegmentType.NONE:
        return value is None
    else:
        raise AssertionError("this statement should be unreachable.")
|
||||
|
||||
def exposed_type(self) -> "SegmentType":
    """Return the type as it should be presented to the frontend.

    The frontend does not distinguish `INTEGER` from `FLOAT`; both are
    reported as `NUMBER`. Every other type is returned unchanged.
    """
    collapsed_to_number = (SegmentType.INTEGER, SegmentType.FLOAT)
    return SegmentType.NUMBER if self in collapsed_to_number else self
|
||||
|
||||
|
||||
# Maps each typed array `SegmentType` to the `SegmentType` of its elements.
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
    # ARRAY_ANY does not have a corresponding element type.
    SegmentType.ARRAY_STRING: SegmentType.STRING,
    SegmentType.ARRAY_NUMBER: SegmentType.NUMBER,
    SegmentType.ARRAY_OBJECT: SegmentType.OBJECT,
    SegmentType.ARRAY_FILE: SegmentType.FILE,
}

# Every array type: the typed arrays from the mapping above plus ARRAY_ANY.
# Deriving the typed members from the mapping keeps the two in sync.
_ARRAY_TYPES = frozenset(
    list(_ARRAY_ELEMENT_TYPES_MAPPING.keys())
    + [
        SegmentType.ARRAY_ANY,
    ]
)

# Types that count as numbers when inferring the type of a mixed list.
_NUMERICAL_TYPES = frozenset(
    [
        SegmentType.NUMBER,
        SegmentType.INTEGER,
        SegmentType.FLOAT,
    ]
)
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
from typing import Annotated, TypeAlias, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import Discriminator, Field, Tag
|
||||
|
||||
from core.helper import encrypter
|
||||
|
||||
@ -20,6 +20,7 @@ from .segments import (
|
||||
ObjectSegment,
|
||||
Segment,
|
||||
StringSegment,
|
||||
get_segment_discriminator,
|
||||
)
|
||||
from .types import SegmentType
|
||||
|
||||
@ -27,6 +28,10 @@ from .types import SegmentType
|
||||
class Variable(Segment):
|
||||
"""
|
||||
A variable is a segment that has a name.
|
||||
|
||||
It is mainly used to store segments and their selector in VariablePool.
|
||||
|
||||
Note: this class is abstract, you should use subclasses of this class instead.
|
||||
"""
|
||||
|
||||
id: str = Field(
|
||||
@ -93,3 +98,28 @@ class FileVariable(FileSegment, Variable):
|
||||
|
||||
# Combines `ArrayFileSegment` with `ArrayVariable`; defines no members of its own.
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
    pass
|
||||
|
||||
|
||||
# `VariableUnion` is the type to use when a variable has to be serialized or
# deserialized by Pydantic. Use plain `Variable` for type hints where
# serialization is not required.
#
# Note:
# - All variants in `VariableUnion` must inherit from the `Variable` class.
# - NOTE(review): the original comment here read "The union must include all
#   non-abstract subclasses of `Segment`, except:" and then listed nothing.
#   Presumably every concrete `Variable` subclass is meant to be listed below;
#   confirm and complete this rule.
VariableUnion: TypeAlias = Annotated[
    (
        Annotated[NoneVariable, Tag(SegmentType.NONE)]
        | Annotated[StringVariable, Tag(SegmentType.STRING)]
        | Annotated[FloatVariable, Tag(SegmentType.FLOAT)]
        | Annotated[IntegerVariable, Tag(SegmentType.INTEGER)]
        | Annotated[ObjectVariable, Tag(SegmentType.OBJECT)]
        | Annotated[FileVariable, Tag(SegmentType.FILE)]
        | Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)]
        | Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)]
        | Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)]
        | Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
        | Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
        | Annotated[SecretVariable, Tag(SegmentType.SECRET)]
    ),
    # The tag of each variant is matched against the value returned here.
    Discriminator(get_segment_discriminator),
]
|
||||
|
||||
Reference in New Issue
Block a user