mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 02:18:08 +08:00
refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025)
refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025) This PR addresses serialization issues in the VariablePool model by separating the `value_type` tags for `IntegerSegment`/`FloatSegment` and `IntegerVariable`/`FloatVariable`. Previously, both Integer and Float types shared the same `SegmentType.NUMBER` tag, causing conflicts during serialization. Key changes: - Introduce distinct `value_type` tags for Integer and Float segments/variables - Add `VariableUnion` and `SegmentUnion` types for proper type discrimination - Leverage Pydantic's discriminated union feature for seamless serialization/deserialization - Enable accurate serialization of data structures containing these types Closes #22024.
This commit is contained in:
@ -68,13 +68,18 @@ def _create_pagination_parser():
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
|
||||||
|
value_type = workflow_draft_var.value_type
|
||||||
|
return value_type.exposed_type().value
|
||||||
|
|
||||||
|
|
||||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
||||||
"id": fields.String,
|
"id": fields.String,
|
||||||
"type": fields.String(attribute=lambda model: model.get_variable_type()),
|
"type": fields.String(attribute=lambda model: model.get_variable_type()),
|
||||||
"name": fields.String,
|
"name": fields.String,
|
||||||
"description": fields.String,
|
"description": fields.String,
|
||||||
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
||||||
"value_type": fields.String,
|
"value_type": fields.String(attribute=_serialize_variable_type),
|
||||||
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||||
"visible": fields.Boolean,
|
"visible": fields.Boolean,
|
||||||
}
|
}
|
||||||
@ -90,7 +95,7 @@ _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
|
|||||||
"name": fields.String,
|
"name": fields.String,
|
||||||
"description": fields.String,
|
"description": fields.String,
|
||||||
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
||||||
"value_type": fields.String,
|
"value_type": fields.String(attribute=_serialize_variable_type),
|
||||||
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||||
"visible": fields.Boolean,
|
"visible": fields.Boolean,
|
||||||
}
|
}
|
||||||
@ -396,7 +401,7 @@ class EnvironmentVariableCollectionApi(Resource):
|
|||||||
"name": v.name,
|
"name": v.name,
|
||||||
"description": v.description,
|
"description": v.description,
|
||||||
"selector": v.selector,
|
"selector": v.selector,
|
||||||
"value_type": v.value_type.value,
|
"value_type": v.value_type.exposed_type().value,
|
||||||
"value": v.value,
|
"value": v.value,
|
||||||
# Do not track edited for env vars.
|
# Do not track edited for env vars.
|
||||||
"edited": False,
|
"edited": False,
|
||||||
|
|||||||
@ -16,9 +16,10 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueTextChunkEvent,
|
QueueTextChunkEvent,
|
||||||
)
|
)
|
||||||
from core.moderation.base import ModerationError
|
from core.moderation.base import ModerationError
|
||||||
|
from core.variables.variables import VariableUnion
|
||||||
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.system_variable import SystemVariable
|
||||||
from core.workflow.variable_loader import VariableLoader
|
from core.workflow.variable_loader import VariableLoader
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -64,7 +65,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||||||
if not workflow:
|
if not workflow:
|
||||||
raise ValueError("Workflow not initialized")
|
raise ValueError("Workflow not initialized")
|
||||||
|
|
||||||
user_id = None
|
user_id: str | None = None
|
||||||
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||||
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
|
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
|
||||||
if end_user:
|
if end_user:
|
||||||
@ -136,23 +137,25 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
# Create a variable pool.
|
# Create a variable pool.
|
||||||
system_inputs = {
|
system_inputs = SystemVariable(
|
||||||
SystemVariableKey.QUERY: query,
|
query=query,
|
||||||
SystemVariableKey.FILES: files,
|
files=files,
|
||||||
SystemVariableKey.CONVERSATION_ID: self.conversation.id,
|
conversation_id=self.conversation.id,
|
||||||
SystemVariableKey.USER_ID: user_id,
|
user_id=user_id,
|
||||||
SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count,
|
dialogue_count=self._dialogue_count,
|
||||||
SystemVariableKey.APP_ID: app_config.app_id,
|
app_id=app_config.app_id,
|
||||||
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
|
workflow_id=app_config.workflow_id,
|
||||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_run_id,
|
workflow_execution_id=self.application_generate_entity.workflow_run_id,
|
||||||
}
|
)
|
||||||
|
|
||||||
# init variable pool
|
# init variable pool
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables=system_inputs,
|
system_variables=system_inputs,
|
||||||
user_inputs=inputs,
|
user_inputs=inputs,
|
||||||
environment_variables=workflow.environment_variables,
|
environment_variables=workflow.environment_variables,
|
||||||
conversation_variables=conversation_variables,
|
# Based on the definition of `VariableUnion`,
|
||||||
|
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||||
|
conversation_variables=cast(list[VariableUnion], conversation_variables),
|
||||||
)
|
)
|
||||||
|
|
||||||
# init graph
|
# init graph
|
||||||
|
|||||||
@ -61,12 +61,12 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
|||||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
|
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
|
||||||
from core.workflow.enums import SystemVariableKey
|
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
||||||
from events.message_event import message_was_created
|
from events.message_event import message_was_created
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -116,16 +116,16 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
|
|
||||||
self._workflow_cycle_manager = WorkflowCycleManager(
|
self._workflow_cycle_manager = WorkflowCycleManager(
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
workflow_system_variables={
|
workflow_system_variables=SystemVariable(
|
||||||
SystemVariableKey.QUERY: message.query,
|
query=message.query,
|
||||||
SystemVariableKey.FILES: application_generate_entity.files,
|
files=application_generate_entity.files,
|
||||||
SystemVariableKey.CONVERSATION_ID: conversation.id,
|
conversation_id=conversation.id,
|
||||||
SystemVariableKey.USER_ID: user_session_id,
|
user_id=user_session_id,
|
||||||
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
|
dialogue_count=dialogue_count,
|
||||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
workflow_id=workflow.id,
|
||||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_run_id,
|
workflow_execution_id=application_generate_entity.workflow_run_id,
|
||||||
},
|
),
|
||||||
workflow_info=CycleManagerWorkflowInfo(
|
workflow_info=CycleManagerWorkflowInfo(
|
||||||
workflow_id=workflow.id,
|
workflow_id=workflow.id,
|
||||||
workflow_type=WorkflowType(workflow.type),
|
workflow_type=WorkflowType(workflow.type),
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from core.app.entities.app_invoke_entities import (
|
|||||||
)
|
)
|
||||||
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.system_variable import SystemVariable
|
||||||
from core.workflow.variable_loader import VariableLoader
|
from core.workflow.variable_loader import VariableLoader
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -95,13 +95,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||||||
files = self.application_generate_entity.files
|
files = self.application_generate_entity.files
|
||||||
|
|
||||||
# Create a variable pool.
|
# Create a variable pool.
|
||||||
system_inputs = {
|
|
||||||
SystemVariableKey.FILES: files,
|
system_inputs = SystemVariable(
|
||||||
SystemVariableKey.USER_ID: user_id,
|
files=files,
|
||||||
SystemVariableKey.APP_ID: app_config.app_id,
|
user_id=user_id,
|
||||||
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
|
app_id=app_config.app_id,
|
||||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id,
|
workflow_id=app_config.workflow_id,
|
||||||
}
|
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
|
||||||
|
)
|
||||||
|
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables=system_inputs,
|
system_variables=system_inputs,
|
||||||
|
|||||||
@ -54,10 +54,10 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
|
|||||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
|
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
|
||||||
from core.workflow.enums import SystemVariableKey
|
|
||||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
@ -107,13 +107,13 @@ class WorkflowAppGenerateTaskPipeline:
|
|||||||
|
|
||||||
self._workflow_cycle_manager = WorkflowCycleManager(
|
self._workflow_cycle_manager = WorkflowCycleManager(
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
workflow_system_variables={
|
workflow_system_variables=SystemVariable(
|
||||||
SystemVariableKey.FILES: application_generate_entity.files,
|
files=application_generate_entity.files,
|
||||||
SystemVariableKey.USER_ID: user_session_id,
|
user_id=user_session_id,
|
||||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
workflow_id=workflow.id,
|
||||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_execution_id,
|
workflow_execution_id=application_generate_entity.workflow_execution_id,
|
||||||
},
|
),
|
||||||
workflow_info=CycleManagerWorkflowInfo(
|
workflow_info=CycleManagerWorkflowInfo(
|
||||||
workflow_id=workflow.id,
|
workflow_id=workflow.id,
|
||||||
workflow_type=WorkflowType(workflow.type),
|
workflow_type=WorkflowType(workflow.type),
|
||||||
|
|||||||
@ -62,6 +62,7 @@ from core.workflow.graph_engine.entities.event import (
|
|||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -166,7 +167,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||||||
|
|
||||||
# init variable pool
|
# init variable pool
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={},
|
system_variables=SystemVariable.empty(),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=workflow.environment_variables,
|
environment_variables=workflow.environment_variables,
|
||||||
)
|
)
|
||||||
@ -263,7 +264,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||||||
|
|
||||||
# init variable pool
|
# init variable pool
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={},
|
system_variables=SystemVariable.empty(),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=workflow.environment_variables,
|
environment_variables=workflow.environment_variables,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -158,7 +158,7 @@ class AdvancedPromptTransform(PromptTransform):
|
|||||||
|
|
||||||
if prompt_item.edition_type == "basic" or not prompt_item.edition_type:
|
if prompt_item.edition_type == "basic" or not prompt_item.edition_type:
|
||||||
if self.with_variable_tmpl:
|
if self.with_variable_tmpl:
|
||||||
vp = VariablePool()
|
vp = VariablePool.empty()
|
||||||
for k, v in inputs.items():
|
for k, v in inputs.items():
|
||||||
if k.startswith("#"):
|
if k.startswith("#"):
|
||||||
vp.add(k[1:-1].split("."), v)
|
vp.add(k[1:-1].split("."), v)
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any
|
from typing import Annotated, Any, TypeAlias
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, field_validator
|
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
|
||||||
|
|
||||||
from core.file import File
|
from core.file import File
|
||||||
|
|
||||||
@ -11,6 +11,11 @@ from .types import SegmentType
|
|||||||
|
|
||||||
|
|
||||||
class Segment(BaseModel):
|
class Segment(BaseModel):
|
||||||
|
"""Segment is runtime type used during the execution of workflow.
|
||||||
|
|
||||||
|
Note: this class is abstract, you should use subclasses of this class instead.
|
||||||
|
"""
|
||||||
|
|
||||||
model_config = ConfigDict(frozen=True)
|
model_config = ConfigDict(frozen=True)
|
||||||
|
|
||||||
value_type: SegmentType
|
value_type: SegmentType
|
||||||
@ -73,7 +78,7 @@ class StringSegment(Segment):
|
|||||||
|
|
||||||
|
|
||||||
class FloatSegment(Segment):
|
class FloatSegment(Segment):
|
||||||
value_type: SegmentType = SegmentType.NUMBER
|
value_type: SegmentType = SegmentType.FLOAT
|
||||||
value: float
|
value: float
|
||||||
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
|
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
|
||||||
# The following tests cannot pass.
|
# The following tests cannot pass.
|
||||||
@ -92,7 +97,7 @@ class FloatSegment(Segment):
|
|||||||
|
|
||||||
|
|
||||||
class IntegerSegment(Segment):
|
class IntegerSegment(Segment):
|
||||||
value_type: SegmentType = SegmentType.NUMBER
|
value_type: SegmentType = SegmentType.INTEGER
|
||||||
value: int
|
value: int
|
||||||
|
|
||||||
|
|
||||||
@ -181,3 +186,46 @@ class ArrayFileSegment(ArraySegment):
|
|||||||
@property
|
@property
|
||||||
def text(self) -> str:
|
def text(self) -> str:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def get_segment_discriminator(v: Any) -> SegmentType | None:
|
||||||
|
if isinstance(v, Segment):
|
||||||
|
return v.value_type
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
value_type = v.get("value_type")
|
||||||
|
if value_type is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
seg_type = SegmentType(value_type)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
return seg_type
|
||||||
|
else:
|
||||||
|
# return None if the discriminator value isn't found
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# The `SegmentUnion`` type is used to enable serialization and deserialization with Pydantic.
|
||||||
|
# Use `Segment` for type hinting when serialization is not required.
|
||||||
|
#
|
||||||
|
# Note:
|
||||||
|
# - All variants in `SegmentUnion` must inherit from the `Segment` class.
|
||||||
|
# - The union must include all non-abstract subclasses of `Segment`, except:
|
||||||
|
# - `SegmentGroup`, which is not added to the variable pool.
|
||||||
|
# - `Variable` and its subclasses, which are handled by `VariableUnion`.
|
||||||
|
SegmentUnion: TypeAlias = Annotated[
|
||||||
|
(
|
||||||
|
Annotated[NoneSegment, Tag(SegmentType.NONE)]
|
||||||
|
| Annotated[StringSegment, Tag(SegmentType.STRING)]
|
||||||
|
| Annotated[FloatSegment, Tag(SegmentType.FLOAT)]
|
||||||
|
| Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
|
||||||
|
| Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
|
||||||
|
| Annotated[FileSegment, Tag(SegmentType.FILE)]
|
||||||
|
| Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
|
||||||
|
| Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
|
||||||
|
| Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
|
||||||
|
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
|
||||||
|
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
|
||||||
|
),
|
||||||
|
Discriminator(get_segment_discriminator),
|
||||||
|
]
|
||||||
|
|||||||
@ -1,8 +1,27 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from core.file.models import File
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayValidation(StrEnum):
|
||||||
|
"""Strategy for validating array elements"""
|
||||||
|
|
||||||
|
# Skip element validation (only check array container)
|
||||||
|
NONE = "none"
|
||||||
|
|
||||||
|
# Validate the first element (if array is non-empty)
|
||||||
|
FIRST = "first"
|
||||||
|
|
||||||
|
# Validate all elements in the array.
|
||||||
|
ALL = "all"
|
||||||
|
|
||||||
|
|
||||||
class SegmentType(StrEnum):
|
class SegmentType(StrEnum):
|
||||||
NUMBER = "number"
|
NUMBER = "number"
|
||||||
|
INTEGER = "integer"
|
||||||
|
FLOAT = "float"
|
||||||
STRING = "string"
|
STRING = "string"
|
||||||
OBJECT = "object"
|
OBJECT = "object"
|
||||||
SECRET = "secret"
|
SECRET = "secret"
|
||||||
@ -19,16 +38,141 @@ class SegmentType(StrEnum):
|
|||||||
|
|
||||||
GROUP = "group"
|
GROUP = "group"
|
||||||
|
|
||||||
def is_array_type(self):
|
def is_array_type(self) -> bool:
|
||||||
return self in _ARRAY_TYPES
|
return self in _ARRAY_TYPES
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]:
|
||||||
|
"""
|
||||||
|
Attempt to infer the `SegmentType` based on the Python type of the `value` parameter.
|
||||||
|
|
||||||
|
Returns `None` if no appropriate `SegmentType` can be determined for the given `value`.
|
||||||
|
For example, this may occur if the input is a generic Python object of type `object`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(value, list):
|
||||||
|
elem_types: set[SegmentType] = set()
|
||||||
|
for i in value:
|
||||||
|
segment_type = cls.infer_segment_type(i)
|
||||||
|
if segment_type is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
elem_types.add(segment_type)
|
||||||
|
|
||||||
|
if len(elem_types) != 1:
|
||||||
|
if elem_types.issubset(_NUMERICAL_TYPES):
|
||||||
|
return SegmentType.ARRAY_NUMBER
|
||||||
|
return SegmentType.ARRAY_ANY
|
||||||
|
elif all(i.is_array_type() for i in elem_types):
|
||||||
|
return SegmentType.ARRAY_ANY
|
||||||
|
match elem_types.pop():
|
||||||
|
case SegmentType.STRING:
|
||||||
|
return SegmentType.ARRAY_STRING
|
||||||
|
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
|
||||||
|
return SegmentType.ARRAY_NUMBER
|
||||||
|
case SegmentType.OBJECT:
|
||||||
|
return SegmentType.ARRAY_OBJECT
|
||||||
|
case SegmentType.FILE:
|
||||||
|
return SegmentType.ARRAY_FILE
|
||||||
|
case SegmentType.NONE:
|
||||||
|
return SegmentType.ARRAY_ANY
|
||||||
|
case _:
|
||||||
|
# This should be unreachable.
|
||||||
|
raise ValueError(f"not supported value {value}")
|
||||||
|
if value is None:
|
||||||
|
return SegmentType.NONE
|
||||||
|
elif isinstance(value, int) and not isinstance(value, bool):
|
||||||
|
return SegmentType.INTEGER
|
||||||
|
elif isinstance(value, float):
|
||||||
|
return SegmentType.FLOAT
|
||||||
|
elif isinstance(value, str):
|
||||||
|
return SegmentType.STRING
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
return SegmentType.OBJECT
|
||||||
|
elif isinstance(value, File):
|
||||||
|
return SegmentType.FILE
|
||||||
|
elif isinstance(value, str):
|
||||||
|
return SegmentType.STRING
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _validate_array(self, value: Any, array_validation: ArrayValidation) -> bool:
|
||||||
|
if not isinstance(value, list):
|
||||||
|
return False
|
||||||
|
# Skip element validation if array is empty
|
||||||
|
if len(value) == 0:
|
||||||
|
return True
|
||||||
|
if self == SegmentType.ARRAY_ANY:
|
||||||
|
return True
|
||||||
|
element_type = _ARRAY_ELEMENT_TYPES_MAPPING[self]
|
||||||
|
|
||||||
|
if array_validation == ArrayValidation.NONE:
|
||||||
|
return True
|
||||||
|
elif array_validation == ArrayValidation.FIRST:
|
||||||
|
return element_type.is_valid(value[0])
|
||||||
|
else:
|
||||||
|
return all([element_type.is_valid(i, array_validation=ArrayValidation.NONE)] for i in value)
|
||||||
|
|
||||||
|
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a value matches the segment type.
|
||||||
|
Users of `SegmentType` should call this method, instead of using
|
||||||
|
`isinstance` manually.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: The value to validate
|
||||||
|
array_validation: Validation strategy for array types (ignored for non-array types)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the value matches the type under the given validation strategy
|
||||||
|
"""
|
||||||
|
if self.is_array_type():
|
||||||
|
return self._validate_array(value, array_validation)
|
||||||
|
elif self == SegmentType.NUMBER:
|
||||||
|
return isinstance(value, (int, float))
|
||||||
|
elif self == SegmentType.STRING:
|
||||||
|
return isinstance(value, str)
|
||||||
|
elif self == SegmentType.OBJECT:
|
||||||
|
return isinstance(value, dict)
|
||||||
|
elif self == SegmentType.SECRET:
|
||||||
|
return isinstance(value, str)
|
||||||
|
elif self == SegmentType.FILE:
|
||||||
|
return isinstance(value, File)
|
||||||
|
elif self == SegmentType.NONE:
|
||||||
|
return value is None
|
||||||
|
else:
|
||||||
|
raise AssertionError("this statement should be unreachable.")
|
||||||
|
|
||||||
|
def exposed_type(self) -> "SegmentType":
|
||||||
|
"""Returns the type exposed to the frontend.
|
||||||
|
|
||||||
|
The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here.
|
||||||
|
"""
|
||||||
|
if self in (SegmentType.INTEGER, SegmentType.FLOAT):
|
||||||
|
return SegmentType.NUMBER
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
|
||||||
|
# ARRAY_ANY does not have correpond element type.
|
||||||
|
SegmentType.ARRAY_STRING: SegmentType.STRING,
|
||||||
|
SegmentType.ARRAY_NUMBER: SegmentType.NUMBER,
|
||||||
|
SegmentType.ARRAY_OBJECT: SegmentType.OBJECT,
|
||||||
|
SegmentType.ARRAY_FILE: SegmentType.FILE,
|
||||||
|
}
|
||||||
|
|
||||||
_ARRAY_TYPES = frozenset(
|
_ARRAY_TYPES = frozenset(
|
||||||
[
|
list(_ARRAY_ELEMENT_TYPES_MAPPING.keys())
|
||||||
|
+ [
|
||||||
SegmentType.ARRAY_ANY,
|
SegmentType.ARRAY_ANY,
|
||||||
SegmentType.ARRAY_STRING,
|
]
|
||||||
SegmentType.ARRAY_NUMBER,
|
)
|
||||||
SegmentType.ARRAY_OBJECT,
|
|
||||||
SegmentType.ARRAY_FILE,
|
|
||||||
|
_NUMERICAL_TYPES = frozenset(
|
||||||
|
[
|
||||||
|
SegmentType.NUMBER,
|
||||||
|
SegmentType.INTEGER,
|
||||||
|
SegmentType.FLOAT,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import cast
|
from typing import Annotated, TypeAlias, cast
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Discriminator, Field, Tag
|
||||||
|
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
|
|
||||||
@ -20,6 +20,7 @@ from .segments import (
|
|||||||
ObjectSegment,
|
ObjectSegment,
|
||||||
Segment,
|
Segment,
|
||||||
StringSegment,
|
StringSegment,
|
||||||
|
get_segment_discriminator,
|
||||||
)
|
)
|
||||||
from .types import SegmentType
|
from .types import SegmentType
|
||||||
|
|
||||||
@ -27,6 +28,10 @@ from .types import SegmentType
|
|||||||
class Variable(Segment):
|
class Variable(Segment):
|
||||||
"""
|
"""
|
||||||
A variable is a segment that has a name.
|
A variable is a segment that has a name.
|
||||||
|
|
||||||
|
It is mainly used to store segments and their selector in VariablePool.
|
||||||
|
|
||||||
|
Note: this class is abstract, you should use subclasses of this class instead.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: str = Field(
|
id: str = Field(
|
||||||
@ -93,3 +98,28 @@ class FileVariable(FileSegment, Variable):
|
|||||||
|
|
||||||
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
|
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic.
|
||||||
|
# Use `Variable` for type hinting when serialization is not required.
|
||||||
|
#
|
||||||
|
# Note:
|
||||||
|
# - All variants in `VariableUnion` must inherit from the `Variable` class.
|
||||||
|
# - The union must include all non-abstract subclasses of `Segment`, except:
|
||||||
|
VariableUnion: TypeAlias = Annotated[
|
||||||
|
(
|
||||||
|
Annotated[NoneVariable, Tag(SegmentType.NONE)]
|
||||||
|
| Annotated[StringVariable, Tag(SegmentType.STRING)]
|
||||||
|
| Annotated[FloatVariable, Tag(SegmentType.FLOAT)]
|
||||||
|
| Annotated[IntegerVariable, Tag(SegmentType.INTEGER)]
|
||||||
|
| Annotated[ObjectVariable, Tag(SegmentType.OBJECT)]
|
||||||
|
| Annotated[FileVariable, Tag(SegmentType.FILE)]
|
||||||
|
| Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)]
|
||||||
|
| Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)]
|
||||||
|
| Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)]
|
||||||
|
| Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
|
||||||
|
| Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
|
||||||
|
| Annotated[SecretVariable, Tag(SegmentType.SECRET)]
|
||||||
|
),
|
||||||
|
Discriminator(get_segment_discriminator),
|
||||||
|
]
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import re
|
import re
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any, Union
|
from typing import Annotated, Any, Union, cast
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@ -9,8 +9,9 @@ from core.file import File, FileAttribute, file_manager
|
|||||||
from core.variables import Segment, SegmentGroup, Variable
|
from core.variables import Segment, SegmentGroup, Variable
|
||||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||||
from core.variables.segments import FileSegment, NoneSegment
|
from core.variables.segments import FileSegment, NoneSegment
|
||||||
|
from core.variables.variables import VariableUnion
|
||||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.system_variable import SystemVariable
|
||||||
from factories import variable_factory
|
from factories import variable_factory
|
||||||
|
|
||||||
VariableValue = Union[str, int, float, dict, list, File]
|
VariableValue = Union[str, int, float, dict, list, File]
|
||||||
@ -23,31 +24,31 @@ class VariablePool(BaseModel):
|
|||||||
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
||||||
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
||||||
# elements of the selector except the first one.
|
# elements of the selector except the first one.
|
||||||
variable_dictionary: dict[str, dict[int, Segment]] = Field(
|
variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field(
|
||||||
description="Variables mapping",
|
description="Variables mapping",
|
||||||
default=defaultdict(dict),
|
default=defaultdict(dict),
|
||||||
)
|
)
|
||||||
# TODO: This user inputs is not used for pool.
|
|
||||||
|
# The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere.
|
||||||
user_inputs: Mapping[str, Any] = Field(
|
user_inputs: Mapping[str, Any] = Field(
|
||||||
description="User inputs",
|
description="User inputs",
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
)
|
)
|
||||||
system_variables: Mapping[SystemVariableKey, Any] = Field(
|
system_variables: SystemVariable = Field(
|
||||||
description="System variables",
|
description="System variables",
|
||||||
default_factory=dict,
|
|
||||||
)
|
)
|
||||||
environment_variables: Sequence[Variable] = Field(
|
environment_variables: Sequence[VariableUnion] = Field(
|
||||||
description="Environment variables.",
|
description="Environment variables.",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
conversation_variables: Sequence[Variable] = Field(
|
conversation_variables: Sequence[VariableUnion] = Field(
|
||||||
description="Conversation variables.",
|
description="Conversation variables.",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
|
|
||||||
def model_post_init(self, context: Any, /) -> None:
|
def model_post_init(self, context: Any, /) -> None:
|
||||||
for key, value in self.system_variables.items():
|
# Create a mapping from field names to SystemVariableKey enum values
|
||||||
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
|
self._add_system_variables(self.system_variables)
|
||||||
# Add environment variables to the variable pool
|
# Add environment variables to the variable pool
|
||||||
for var in self.environment_variables:
|
for var in self.environment_variables:
|
||||||
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
|
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
|
||||||
@ -83,8 +84,22 @@ class VariablePool(BaseModel):
|
|||||||
segment = variable_factory.build_segment(value)
|
segment = variable_factory.build_segment(value)
|
||||||
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
|
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
|
||||||
|
|
||||||
hash_key = hash(tuple(selector[1:]))
|
key, hash_key = self._selector_to_keys(selector)
|
||||||
self.variable_dictionary[selector[0]][hash_key] = variable
|
# Based on the definition of `VariableUnion`,
|
||||||
|
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||||
|
self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]:
|
||||||
|
return selector[0], hash(tuple(selector[1:]))
|
||||||
|
|
||||||
|
def _has(self, selector: Sequence[str]) -> bool:
|
||||||
|
key, hash_key = self._selector_to_keys(selector)
|
||||||
|
if key not in self.variable_dictionary:
|
||||||
|
return False
|
||||||
|
if hash_key not in self.variable_dictionary[key]:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
def get(self, selector: Sequence[str], /) -> Segment | None:
|
def get(self, selector: Sequence[str], /) -> Segment | None:
|
||||||
"""
|
"""
|
||||||
@ -102,8 +117,8 @@ class VariablePool(BaseModel):
|
|||||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
hash_key = hash(tuple(selector[1:]))
|
key, hash_key = self._selector_to_keys(selector)
|
||||||
value = self.variable_dictionary[selector[0]].get(hash_key)
|
value: Segment | None = self.variable_dictionary[key].get(hash_key)
|
||||||
|
|
||||||
if value is None:
|
if value is None:
|
||||||
selector, attr = selector[:-1], selector[-1]
|
selector, attr = selector[:-1], selector[-1]
|
||||||
@ -136,8 +151,9 @@ class VariablePool(BaseModel):
|
|||||||
if len(selector) == 1:
|
if len(selector) == 1:
|
||||||
self.variable_dictionary[selector[0]] = {}
|
self.variable_dictionary[selector[0]] = {}
|
||||||
return
|
return
|
||||||
|
key, hash_key = self._selector_to_keys(selector)
|
||||||
hash_key = hash(tuple(selector[1:]))
|
hash_key = hash(tuple(selector[1:]))
|
||||||
self.variable_dictionary[selector[0]].pop(hash_key, None)
|
self.variable_dictionary[key].pop(hash_key, None)
|
||||||
|
|
||||||
def convert_template(self, template: str, /):
|
def convert_template(self, template: str, /):
|
||||||
parts = VARIABLE_PATTERN.split(template)
|
parts = VARIABLE_PATTERN.split(template)
|
||||||
@ -154,3 +170,20 @@ class VariablePool(BaseModel):
|
|||||||
if isinstance(segment, FileSegment):
|
if isinstance(segment, FileSegment):
|
||||||
return segment
|
return segment
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _add_system_variables(self, system_variable: SystemVariable):
|
||||||
|
sys_var_mapping = system_variable.to_dict()
|
||||||
|
for key, value in sys_var_mapping.items():
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
selector = (SYSTEM_VARIABLE_NODE_ID, key)
|
||||||
|
# If the system variable already exists, do not add it again.
|
||||||
|
# This ensures that we can keep the id of the system variables intact.
|
||||||
|
if self._has(selector):
|
||||||
|
continue
|
||||||
|
self.add(selector, value) # type: ignore
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def empty(cls) -> "VariablePool":
|
||||||
|
"""Create an empty variable pool."""
|
||||||
|
return cls(system_variables=SystemVariable.empty())
|
||||||
|
|||||||
@ -17,8 +17,12 @@ class GraphRuntimeState(BaseModel):
|
|||||||
"""total tokens"""
|
"""total tokens"""
|
||||||
llm_usage: LLMUsage = LLMUsage.empty_usage()
|
llm_usage: LLMUsage = LLMUsage.empty_usage()
|
||||||
"""llm usage info"""
|
"""llm usage info"""
|
||||||
|
|
||||||
|
# The `outputs` field stores the final output values generated by executing workflows or chatflows.
|
||||||
|
#
|
||||||
|
# Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent
|
||||||
|
# after a serialization and deserialization round trip.
|
||||||
outputs: dict[str, Any] = {}
|
outputs: dict[str, Any] = {}
|
||||||
"""outputs"""
|
|
||||||
|
|
||||||
node_run_steps: int = 0
|
node_run_steps: int = 0
|
||||||
"""node run steps"""
|
"""node run steps"""
|
||||||
|
|||||||
@ -1,11 +1,29 @@
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, Literal, Optional
|
from typing import Annotated, Any, Literal, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import AfterValidator, BaseModel, Field
|
||||||
|
|
||||||
|
from core.variables.types import SegmentType
|
||||||
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
|
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
|
||||||
from core.workflow.utils.condition.entities import Condition
|
from core.workflow.utils.condition.entities import Condition
|
||||||
|
|
||||||
|
_VALID_VAR_TYPE = frozenset(
|
||||||
|
[
|
||||||
|
SegmentType.STRING,
|
||||||
|
SegmentType.NUMBER,
|
||||||
|
SegmentType.OBJECT,
|
||||||
|
SegmentType.ARRAY_STRING,
|
||||||
|
SegmentType.ARRAY_NUMBER,
|
||||||
|
SegmentType.ARRAY_OBJECT,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
|
||||||
|
if seg_type not in _VALID_VAR_TYPE:
|
||||||
|
raise ValueError(...)
|
||||||
|
return seg_type
|
||||||
|
|
||||||
|
|
||||||
class LoopVariableData(BaseModel):
|
class LoopVariableData(BaseModel):
|
||||||
"""
|
"""
|
||||||
@ -13,7 +31,7 @@ class LoopVariableData(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
label: str
|
label: str
|
||||||
var_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
|
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
|
||||||
value_type: Literal["variable", "constant"]
|
value_type: Literal["variable", "constant"]
|
||||||
value: Optional[Any | list[str]] = None
|
value: Optional[Any | list[str]] = None
|
||||||
|
|
||||||
|
|||||||
@ -7,14 +7,9 @@ from typing import TYPE_CHECKING, Any, Literal, cast
|
|||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.variables import (
|
from core.variables import (
|
||||||
ArrayNumberSegment,
|
|
||||||
ArrayObjectSegment,
|
|
||||||
ArrayStringSegment,
|
|
||||||
IntegerSegment,
|
IntegerSegment,
|
||||||
ObjectSegment,
|
|
||||||
Segment,
|
Segment,
|
||||||
SegmentType,
|
SegmentType,
|
||||||
StringSegment,
|
|
||||||
)
|
)
|
||||||
from core.workflow.entities.node_entities import NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||||
@ -39,6 +34,7 @@ from core.workflow.nodes.enums import NodeType
|
|||||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||||
from core.workflow.nodes.loop.entities import LoopNodeData
|
from core.workflow.nodes.loop.entities import LoopNodeData
|
||||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||||
|
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
@ -505,23 +501,21 @@ class LoopNode(BaseNode[LoopNodeData]):
|
|||||||
return variable_mapping
|
return variable_mapping
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_segment_for_constant(var_type: str, value: Any) -> Segment:
|
def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment:
|
||||||
"""Get the appropriate segment type for a constant value."""
|
"""Get the appropriate segment type for a constant value."""
|
||||||
segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = {
|
|
||||||
"string": (StringSegment, SegmentType.STRING),
|
|
||||||
"number": (IntegerSegment, SegmentType.NUMBER),
|
|
||||||
"object": (ObjectSegment, SegmentType.OBJECT),
|
|
||||||
"array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING),
|
|
||||||
"array[number]": (ArrayNumberSegment, SegmentType.ARRAY_NUMBER),
|
|
||||||
"array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT),
|
|
||||||
}
|
|
||||||
if var_type in ["array[string]", "array[number]", "array[object]"]:
|
if var_type in ["array[string]", "array[number]", "array[object]"]:
|
||||||
if value:
|
if value and isinstance(value, str):
|
||||||
value = json.loads(value)
|
value = json.loads(value)
|
||||||
else:
|
else:
|
||||||
value = []
|
value = []
|
||||||
segment_info = segment_mapping.get(var_type)
|
try:
|
||||||
if not segment_info:
|
return build_segment_with_type(var_type, value)
|
||||||
raise ValueError(f"Invalid variable type: {var_type}")
|
except TypeMismatchError as type_exc:
|
||||||
segment_class, value_type = segment_info
|
# Attempt to parse the value as a JSON-encoded string, if applicable.
|
||||||
return segment_class(value=value, value_type=value_type)
|
if not isinstance(value, str):
|
||||||
|
raise
|
||||||
|
try:
|
||||||
|
value = json.loads(value)
|
||||||
|
except ValueError:
|
||||||
|
raise type_exc
|
||||||
|
return build_segment_with_type(var_type, value)
|
||||||
|
|||||||
@ -16,7 +16,7 @@ class StartNode(BaseNode[StartNodeData]):
|
|||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
|
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
|
||||||
system_inputs = self.graph_runtime_state.variable_pool.system_variables
|
system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
|
||||||
|
|
||||||
# TODO: System variables should be directly accessible, no need for special handling
|
# TODO: System variables should be directly accessible, no need for special handling
|
||||||
# Set system variables as node outputs.
|
# Set system variables as node outputs.
|
||||||
|
|||||||
@ -130,6 +130,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
|||||||
|
|
||||||
|
|
||||||
def get_zero_value(t: SegmentType):
|
def get_zero_value(t: SegmentType):
|
||||||
|
# TODO(QuantumGhost): this should be a method of `SegmentType`.
|
||||||
match t:
|
match t:
|
||||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
|
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
|
||||||
return variable_factory.build_segment([])
|
return variable_factory.build_segment([])
|
||||||
@ -137,6 +138,10 @@ def get_zero_value(t: SegmentType):
|
|||||||
return variable_factory.build_segment({})
|
return variable_factory.build_segment({})
|
||||||
case SegmentType.STRING:
|
case SegmentType.STRING:
|
||||||
return variable_factory.build_segment("")
|
return variable_factory.build_segment("")
|
||||||
|
case SegmentType.INTEGER:
|
||||||
|
return variable_factory.build_segment(0)
|
||||||
|
case SegmentType.FLOAT:
|
||||||
|
return variable_factory.build_segment(0.0)
|
||||||
case SegmentType.NUMBER:
|
case SegmentType.NUMBER:
|
||||||
return variable_factory.build_segment(0)
|
return variable_factory.build_segment(0)
|
||||||
case _:
|
case _:
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
from core.variables import SegmentType
|
from core.variables import SegmentType
|
||||||
|
|
||||||
|
# Note: This mapping is duplicated with `get_zero_value`. Consider refactoring to avoid redundancy.
|
||||||
EMPTY_VALUE_MAPPING = {
|
EMPTY_VALUE_MAPPING = {
|
||||||
SegmentType.STRING: "",
|
SegmentType.STRING: "",
|
||||||
SegmentType.NUMBER: 0,
|
SegmentType.NUMBER: 0,
|
||||||
|
|||||||
@ -10,10 +10,16 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
|
|||||||
case Operation.OVER_WRITE | Operation.CLEAR:
|
case Operation.OVER_WRITE | Operation.CLEAR:
|
||||||
return True
|
return True
|
||||||
case Operation.SET:
|
case Operation.SET:
|
||||||
return variable_type in {SegmentType.OBJECT, SegmentType.STRING, SegmentType.NUMBER}
|
return variable_type in {
|
||||||
|
SegmentType.OBJECT,
|
||||||
|
SegmentType.STRING,
|
||||||
|
SegmentType.NUMBER,
|
||||||
|
SegmentType.INTEGER,
|
||||||
|
SegmentType.FLOAT,
|
||||||
|
}
|
||||||
case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
|
case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
|
||||||
# Only number variable can be added, subtracted, multiplied or divided
|
# Only number variable can be added, subtracted, multiplied or divided
|
||||||
return variable_type == SegmentType.NUMBER
|
return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}
|
||||||
case Operation.APPEND | Operation.EXTEND:
|
case Operation.APPEND | Operation.EXTEND:
|
||||||
# Only array variable can be appended or extended
|
# Only array variable can be appended or extended
|
||||||
return variable_type in {
|
return variable_type in {
|
||||||
@ -46,7 +52,7 @@ def is_constant_input_supported(*, variable_type: SegmentType, operation: Operat
|
|||||||
match variable_type:
|
match variable_type:
|
||||||
case SegmentType.STRING | SegmentType.OBJECT:
|
case SegmentType.STRING | SegmentType.OBJECT:
|
||||||
return operation in {Operation.OVER_WRITE, Operation.SET}
|
return operation in {Operation.OVER_WRITE, Operation.SET}
|
||||||
case SegmentType.NUMBER:
|
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
|
||||||
return operation in {
|
return operation in {
|
||||||
Operation.OVER_WRITE,
|
Operation.OVER_WRITE,
|
||||||
Operation.SET,
|
Operation.SET,
|
||||||
@ -66,7 +72,7 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
|
|||||||
case SegmentType.STRING:
|
case SegmentType.STRING:
|
||||||
return isinstance(value, str)
|
return isinstance(value, str)
|
||||||
|
|
||||||
case SegmentType.NUMBER:
|
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
|
||||||
if not isinstance(value, int | float):
|
if not isinstance(value, int | float):
|
||||||
return False
|
return False
|
||||||
if operation == Operation.DIVIDE and value == 0:
|
if operation == Operation.DIVIDE and value == 0:
|
||||||
|
|||||||
89
api/core/workflow/system_variable.py
Normal file
89
api/core/workflow/system_variable.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
|
||||||
|
|
||||||
|
from core.file.models import File
|
||||||
|
from core.workflow.enums import SystemVariableKey
|
||||||
|
|
||||||
|
|
||||||
|
class SystemVariable(BaseModel):
|
||||||
|
"""A model for managing system variables.
|
||||||
|
|
||||||
|
Fields with a value of `None` are treated as absent and will not be included
|
||||||
|
in the variable pool.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
extra="forbid",
|
||||||
|
serialize_by_alias=True,
|
||||||
|
validate_by_alias=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
user_id: str | None = None
|
||||||
|
|
||||||
|
# Ideally, `app_id` and `workflow_id` should be required and not `None`.
|
||||||
|
# However, there are scenarios in the codebase where these fields are not set.
|
||||||
|
# To maintain compatibility, they are marked as optional here.
|
||||||
|
app_id: str | None = None
|
||||||
|
workflow_id: str | None = None
|
||||||
|
|
||||||
|
files: Sequence[File] = Field(default_factory=list)
|
||||||
|
|
||||||
|
# NOTE: The `workflow_execution_id` field was previously named `workflow_run_id`.
|
||||||
|
# To maintain compatibility with existing workflows, it must be serialized
|
||||||
|
# as `workflow_run_id` in dictionaries or JSON objects, and also referenced
|
||||||
|
# as `workflow_run_id` in the variable pool.
|
||||||
|
workflow_execution_id: str | None = Field(
|
||||||
|
validation_alias=AliasChoices("workflow_execution_id", "workflow_run_id"),
|
||||||
|
serialization_alias="workflow_run_id",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
# Chatflow related fields.
|
||||||
|
query: str | None = None
|
||||||
|
conversation_id: str | None = None
|
||||||
|
dialogue_count: int | None = None
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_json_fields(cls, data):
|
||||||
|
if isinstance(data, dict):
|
||||||
|
# For JSON validation, only allow workflow_run_id
|
||||||
|
if "workflow_execution_id" in data and "workflow_run_id" not in data:
|
||||||
|
# This is likely from direct instantiation, allow it
|
||||||
|
return data
|
||||||
|
elif "workflow_execution_id" in data and "workflow_run_id" in data:
|
||||||
|
# Both present, remove workflow_execution_id
|
||||||
|
data = data.copy()
|
||||||
|
data.pop("workflow_execution_id")
|
||||||
|
return data
|
||||||
|
return data
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def empty(cls) -> "SystemVariable":
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[SystemVariableKey, Any]:
|
||||||
|
# NOTE: This method is provided for compatibility with legacy code.
|
||||||
|
# New code should use the `SystemVariable` object directly instead of converting
|
||||||
|
# it to a dictionary, as this conversion results in the loss of type information
|
||||||
|
# for each key, making static analysis more difficult.
|
||||||
|
|
||||||
|
d: dict[SystemVariableKey, Any] = {
|
||||||
|
SystemVariableKey.FILES: self.files,
|
||||||
|
}
|
||||||
|
if self.user_id is not None:
|
||||||
|
d[SystemVariableKey.USER_ID] = self.user_id
|
||||||
|
if self.app_id is not None:
|
||||||
|
d[SystemVariableKey.APP_ID] = self.app_id
|
||||||
|
if self.workflow_id is not None:
|
||||||
|
d[SystemVariableKey.WORKFLOW_ID] = self.workflow_id
|
||||||
|
if self.workflow_execution_id is not None:
|
||||||
|
d[SystemVariableKey.WORKFLOW_EXECUTION_ID] = self.workflow_execution_id
|
||||||
|
if self.query is not None:
|
||||||
|
d[SystemVariableKey.QUERY] = self.query
|
||||||
|
if self.conversation_id is not None:
|
||||||
|
d[SystemVariableKey.CONVERSATION_ID] = self.conversation_id
|
||||||
|
if self.dialogue_count is not None:
|
||||||
|
d[SystemVariableKey.DIALOGUE_COUNT] = self.dialogue_count
|
||||||
|
return d
|
||||||
@ -26,6 +26,7 @@ from core.workflow.entities.workflow_node_execution import (
|
|||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
|
|
||||||
@ -43,7 +44,7 @@ class WorkflowCycleManager:
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||||
workflow_system_variables: dict[SystemVariableKey, Any],
|
workflow_system_variables: SystemVariable,
|
||||||
workflow_info: CycleManagerWorkflowInfo,
|
workflow_info: CycleManagerWorkflowInfo,
|
||||||
workflow_execution_repository: WorkflowExecutionRepository,
|
workflow_execution_repository: WorkflowExecutionRepository,
|
||||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||||
@ -56,17 +57,22 @@ class WorkflowCycleManager:
|
|||||||
|
|
||||||
def handle_workflow_run_start(self) -> WorkflowExecution:
|
def handle_workflow_run_start(self) -> WorkflowExecution:
|
||||||
inputs = {**self._application_generate_entity.inputs}
|
inputs = {**self._application_generate_entity.inputs}
|
||||||
for key, value in (self._workflow_system_variables or {}).items():
|
|
||||||
if key.value == "conversation":
|
# Iterate over SystemVariable fields using Pydantic's model_fields
|
||||||
continue
|
if self._workflow_system_variables:
|
||||||
inputs[f"sys.{key.value}"] = value
|
for field_name, value in self._workflow_system_variables.to_dict().items():
|
||||||
|
if field_name == SystemVariableKey.CONVERSATION_ID:
|
||||||
|
continue
|
||||||
|
inputs[f"sys.{field_name}"] = value
|
||||||
|
|
||||||
# handle special values
|
# handle special values
|
||||||
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
|
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
|
||||||
|
|
||||||
# init workflow run
|
# init workflow run
|
||||||
# TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
|
# TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
|
||||||
execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_EXECUTION_ID) or uuid4())
|
execution_id = str(
|
||||||
|
self._workflow_system_variables.workflow_execution_id if self._workflow_system_variables else None
|
||||||
|
) or str(uuid4())
|
||||||
execution = WorkflowExecution.new(
|
execution = WorkflowExecution.new(
|
||||||
id_=execution_id,
|
id_=execution_id,
|
||||||
workflow_id=self._workflow_info.workflow_id,
|
workflow_id=self._workflow_info.workflow_id,
|
||||||
|
|||||||
@ -21,6 +21,7 @@ from core.workflow.nodes import NodeType
|
|||||||
from core.workflow.nodes.base import BaseNode
|
from core.workflow.nodes.base import BaseNode
|
||||||
from core.workflow.nodes.event import NodeEvent
|
from core.workflow.nodes.event import NodeEvent
|
||||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
@ -254,7 +255,7 @@ class WorkflowEntry:
|
|||||||
|
|
||||||
# init variable pool
|
# init variable pool
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={},
|
system_variables=SystemVariable.empty(),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -91,9 +91,13 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
|
|||||||
result = StringVariable.model_validate(mapping)
|
result = StringVariable.model_validate(mapping)
|
||||||
case SegmentType.SECRET:
|
case SegmentType.SECRET:
|
||||||
result = SecretVariable.model_validate(mapping)
|
result = SecretVariable.model_validate(mapping)
|
||||||
case SegmentType.NUMBER if isinstance(value, int):
|
case SegmentType.NUMBER | SegmentType.INTEGER if isinstance(value, int):
|
||||||
|
mapping = dict(mapping)
|
||||||
|
mapping["value_type"] = SegmentType.INTEGER
|
||||||
result = IntegerVariable.model_validate(mapping)
|
result = IntegerVariable.model_validate(mapping)
|
||||||
case SegmentType.NUMBER if isinstance(value, float):
|
case SegmentType.NUMBER | SegmentType.FLOAT if isinstance(value, float):
|
||||||
|
mapping = dict(mapping)
|
||||||
|
mapping["value_type"] = SegmentType.FLOAT
|
||||||
result = FloatVariable.model_validate(mapping)
|
result = FloatVariable.model_validate(mapping)
|
||||||
case SegmentType.NUMBER if not isinstance(value, float | int):
|
case SegmentType.NUMBER if not isinstance(value, float | int):
|
||||||
raise VariableError(f"invalid number value {value}")
|
raise VariableError(f"invalid number value {value}")
|
||||||
@ -119,6 +123,8 @@ def infer_segment_type_from_value(value: Any, /) -> SegmentType:
|
|||||||
|
|
||||||
|
|
||||||
def build_segment(value: Any, /) -> Segment:
|
def build_segment(value: Any, /) -> Segment:
|
||||||
|
# NOTE: If you have runtime type information available, consider using the `build_segment_with_type`
|
||||||
|
# below
|
||||||
if value is None:
|
if value is None:
|
||||||
return NoneSegment()
|
return NoneSegment()
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
@ -134,12 +140,17 @@ def build_segment(value: Any, /) -> Segment:
|
|||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
items = [build_segment(item) for item in value]
|
items = [build_segment(item) for item in value]
|
||||||
types = {item.value_type for item in items}
|
types = {item.value_type for item in items}
|
||||||
if len(types) != 1 or all(isinstance(item, ArraySegment) for item in items):
|
if all(isinstance(item, ArraySegment) for item in items):
|
||||||
return ArrayAnySegment(value=value)
|
return ArrayAnySegment(value=value)
|
||||||
|
elif len(types) != 1:
|
||||||
|
if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}):
|
||||||
|
return ArrayNumberSegment(value=value)
|
||||||
|
return ArrayAnySegment(value=value)
|
||||||
|
|
||||||
match types.pop():
|
match types.pop():
|
||||||
case SegmentType.STRING:
|
case SegmentType.STRING:
|
||||||
return ArrayStringSegment(value=value)
|
return ArrayStringSegment(value=value)
|
||||||
case SegmentType.NUMBER:
|
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
|
||||||
return ArrayNumberSegment(value=value)
|
return ArrayNumberSegment(value=value)
|
||||||
case SegmentType.OBJECT:
|
case SegmentType.OBJECT:
|
||||||
return ArrayObjectSegment(value=value)
|
return ArrayObjectSegment(value=value)
|
||||||
@ -153,6 +164,22 @@ def build_segment(value: Any, /) -> Segment:
|
|||||||
raise ValueError(f"not supported value {value}")
|
raise ValueError(f"not supported value {value}")
|
||||||
|
|
||||||
|
|
||||||
|
_segment_factory: Mapping[SegmentType, type[Segment]] = {
|
||||||
|
SegmentType.NONE: NoneSegment,
|
||||||
|
SegmentType.STRING: StringSegment,
|
||||||
|
SegmentType.INTEGER: IntegerSegment,
|
||||||
|
SegmentType.FLOAT: FloatSegment,
|
||||||
|
SegmentType.FILE: FileSegment,
|
||||||
|
SegmentType.OBJECT: ObjectSegment,
|
||||||
|
# Array types
|
||||||
|
SegmentType.ARRAY_ANY: ArrayAnySegment,
|
||||||
|
SegmentType.ARRAY_STRING: ArrayStringSegment,
|
||||||
|
SegmentType.ARRAY_NUMBER: ArrayNumberSegment,
|
||||||
|
SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
|
||||||
|
SegmentType.ARRAY_FILE: ArrayFileSegment,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
|
def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
|
||||||
"""
|
"""
|
||||||
Build a segment with explicit type checking.
|
Build a segment with explicit type checking.
|
||||||
@ -190,7 +217,7 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
|
|||||||
if segment_type == SegmentType.NONE:
|
if segment_type == SegmentType.NONE:
|
||||||
return NoneSegment()
|
return NoneSegment()
|
||||||
else:
|
else:
|
||||||
raise TypeMismatchError(f"Expected {segment_type}, but got None")
|
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None")
|
||||||
|
|
||||||
# Handle empty list special case for array types
|
# Handle empty list special case for array types
|
||||||
if isinstance(value, list) and len(value) == 0:
|
if isinstance(value, list) and len(value) == 0:
|
||||||
@ -205,21 +232,25 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
|
|||||||
elif segment_type == SegmentType.ARRAY_FILE:
|
elif segment_type == SegmentType.ARRAY_FILE:
|
||||||
return ArrayFileSegment(value=value)
|
return ArrayFileSegment(value=value)
|
||||||
else:
|
else:
|
||||||
raise TypeMismatchError(f"Expected {segment_type}, but got empty list")
|
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list")
|
||||||
|
|
||||||
# Build segment using existing logic to infer actual type
|
|
||||||
inferred_segment = build_segment(value)
|
|
||||||
inferred_type = inferred_segment.value_type
|
|
||||||
|
|
||||||
|
inferred_type = SegmentType.infer_segment_type(value)
|
||||||
# Type compatibility checking
|
# Type compatibility checking
|
||||||
|
if inferred_type is None:
|
||||||
|
raise TypeMismatchError(
|
||||||
|
f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}"
|
||||||
|
)
|
||||||
if inferred_type == segment_type:
|
if inferred_type == segment_type:
|
||||||
return inferred_segment
|
segment_class = _segment_factory[segment_type]
|
||||||
|
return segment_class(value_type=segment_type, value=value)
|
||||||
# Type mismatch - raise error with descriptive message
|
elif segment_type == SegmentType.NUMBER and inferred_type in (
|
||||||
raise TypeMismatchError(
|
SegmentType.INTEGER,
|
||||||
f"Type mismatch: expected {segment_type}, but value '{value}' "
|
SegmentType.FLOAT,
|
||||||
f"(type: {type(value).__name__}) corresponds to {inferred_type}"
|
):
|
||||||
)
|
segment_class = _segment_factory[inferred_type]
|
||||||
|
return segment_class(value_type=inferred_type, value=value)
|
||||||
|
else:
|
||||||
|
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}")
|
||||||
|
|
||||||
|
|
||||||
def segment_to_variable(
|
def segment_to_variable(
|
||||||
@ -247,6 +278,6 @@ def segment_to_variable(
|
|||||||
name=name,
|
name=name,
|
||||||
description=description,
|
description=description,
|
||||||
value=segment.value,
|
value=segment.value,
|
||||||
selector=selector,
|
selector=list(selector),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
15
api/fields/_value_type_serializer.py
Normal file
15
api/fields/_value_type_serializer.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
|
from core.variables.segments import Segment
|
||||||
|
from core.variables.types import SegmentType
|
||||||
|
|
||||||
|
|
||||||
|
class _VarTypedDict(TypedDict, total=False):
|
||||||
|
value_type: SegmentType
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_value_type(v: _VarTypedDict | Segment) -> str:
|
||||||
|
if isinstance(v, Segment):
|
||||||
|
return v.value_type.exposed_type().value
|
||||||
|
else:
|
||||||
|
return v["value_type"].exposed_type().value
|
||||||
@ -2,10 +2,12 @@ from flask_restful import fields
|
|||||||
|
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
|
|
||||||
|
from ._value_type_serializer import serialize_value_type
|
||||||
|
|
||||||
conversation_variable_fields = {
|
conversation_variable_fields = {
|
||||||
"id": fields.String,
|
"id": fields.String,
|
||||||
"name": fields.String,
|
"name": fields.String,
|
||||||
"value_type": fields.String(attribute="value_type.value"),
|
"value_type": fields.String(attribute=serialize_value_type),
|
||||||
"value": fields.String,
|
"value": fields.String,
|
||||||
"description": fields.String,
|
"description": fields.String,
|
||||||
"created_at": TimestampField,
|
"created_at": TimestampField,
|
||||||
|
|||||||
@ -5,6 +5,8 @@ from core.variables import SecretVariable, SegmentType, Variable
|
|||||||
from fields.member_fields import simple_account_fields
|
from fields.member_fields import simple_account_fields
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
|
|
||||||
|
from ._value_type_serializer import serialize_value_type
|
||||||
|
|
||||||
ENVIRONMENT_VARIABLE_SUPPORTED_TYPES = (SegmentType.STRING, SegmentType.NUMBER, SegmentType.SECRET)
|
ENVIRONMENT_VARIABLE_SUPPORTED_TYPES = (SegmentType.STRING, SegmentType.NUMBER, SegmentType.SECRET)
|
||||||
|
|
||||||
|
|
||||||
@ -24,11 +26,16 @@ class EnvironmentVariableField(fields.Raw):
|
|||||||
"id": value.id,
|
"id": value.id,
|
||||||
"name": value.name,
|
"name": value.name,
|
||||||
"value": value.value,
|
"value": value.value,
|
||||||
"value_type": value.value_type.value,
|
"value_type": value.value_type.exposed_type().value,
|
||||||
"description": value.description,
|
"description": value.description,
|
||||||
}
|
}
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
value_type = value.get("value_type")
|
value_type_str = value.get("value_type")
|
||||||
|
if not isinstance(value_type_str, str):
|
||||||
|
raise TypeError(
|
||||||
|
f"unexpected type for value_type field, value={value_type_str}, type={type(value_type_str)}"
|
||||||
|
)
|
||||||
|
value_type = SegmentType(value_type_str).exposed_type()
|
||||||
if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES:
|
if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES:
|
||||||
raise ValueError(f"Unsupported environment variable value type: {value_type}")
|
raise ValueError(f"Unsupported environment variable value type: {value_type}")
|
||||||
return value
|
return value
|
||||||
@ -37,7 +44,7 @@ class EnvironmentVariableField(fields.Raw):
|
|||||||
conversation_variable_fields = {
|
conversation_variable_fields = {
|
||||||
"id": fields.String,
|
"id": fields.String,
|
||||||
"name": fields.String,
|
"name": fields.String,
|
||||||
"value_type": fields.String(attribute="value_type.value"),
|
"value_type": fields.String(attribute=serialize_value_type),
|
||||||
"value": fields.Raw,
|
"value": fields.Raw,
|
||||||
"description": fields.String,
|
"description": fields.String,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -12,6 +12,7 @@ from sqlalchemy import orm
|
|||||||
from core.file.constants import maybe_file_object
|
from core.file.constants import maybe_file_object
|
||||||
from core.file.models import File
|
from core.file.models import File
|
||||||
from core.variables import utils as variable_utils
|
from core.variables import utils as variable_utils
|
||||||
|
from core.variables.variables import FloatVariable, IntegerVariable, StringVariable
|
||||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||||
from core.workflow.nodes.enums import NodeType
|
from core.workflow.nodes.enums import NodeType
|
||||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||||
@ -347,7 +348,7 @@ class Workflow(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def environment_variables(self) -> Sequence[Variable]:
|
def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
|
||||||
# TODO: find some way to init `self._environment_variables` when instance created.
|
# TODO: find some way to init `self._environment_variables` when instance created.
|
||||||
if self._environment_variables is None:
|
if self._environment_variables is None:
|
||||||
self._environment_variables = "{}"
|
self._environment_variables = "{}"
|
||||||
@ -367,11 +368,15 @@ class Workflow(Base):
|
|||||||
def decrypt_func(var):
|
def decrypt_func(var):
|
||||||
if isinstance(var, SecretVariable):
|
if isinstance(var, SecretVariable):
|
||||||
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
|
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
|
||||||
else:
|
elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)):
|
||||||
return var
|
return var
|
||||||
|
else:
|
||||||
|
raise AssertionError("this statement should be unreachable.")
|
||||||
|
|
||||||
results = list(map(decrypt_func, results))
|
decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list(
|
||||||
return results
|
map(decrypt_func, results)
|
||||||
|
)
|
||||||
|
return decrypted_results
|
||||||
|
|
||||||
@environment_variables.setter
|
@environment_variables.setter
|
||||||
def environment_variables(self, value: Sequence[Variable]):
|
def environment_variables(self, value: Sequence[Variable]):
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional, cast
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
@ -15,10 +15,10 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
|||||||
from core.file import File
|
from core.file import File
|
||||||
from core.repositories import DifyCoreRepositoryFactory
|
from core.repositories import DifyCoreRepositoryFactory
|
||||||
from core.variables import Variable
|
from core.variables import Variable
|
||||||
|
from core.variables.variables import VariableUnion
|
||||||
from core.workflow.entities.node_entities import NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.enums import SystemVariableKey
|
|
||||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
@ -28,6 +28,7 @@ from core.workflow.nodes.event import RunCompletedEvent
|
|||||||
from core.workflow.nodes.event.types import NodeEvent
|
from core.workflow.nodes.event.types import NodeEvent
|
||||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||||
from core.workflow.nodes.start.entities import StartNodeData
|
from core.workflow.nodes.start.entities import StartNodeData
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -369,7 +370,7 @@ class WorkflowService:
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={},
|
system_variables=SystemVariable.empty(),
|
||||||
user_inputs=user_inputs,
|
user_inputs=user_inputs,
|
||||||
environment_variables=draft_workflow.environment_variables,
|
environment_variables=draft_workflow.environment_variables,
|
||||||
conversation_variables=[],
|
conversation_variables=[],
|
||||||
@ -685,36 +686,30 @@ def _setup_variable_pool(
|
|||||||
):
|
):
|
||||||
# Only inject system variables for START node type.
|
# Only inject system variables for START node type.
|
||||||
if node_type == NodeType.START:
|
if node_type == NodeType.START:
|
||||||
# Create a variable pool.
|
system_variable = SystemVariable(
|
||||||
system_inputs: dict[SystemVariableKey, Any] = {
|
user_id=user_id,
|
||||||
# From inputs:
|
app_id=workflow.app_id,
|
||||||
SystemVariableKey.FILES: files,
|
workflow_id=workflow.id,
|
||||||
SystemVariableKey.USER_ID: user_id,
|
files=files or [],
|
||||||
# From workflow model
|
workflow_execution_id=str(uuid.uuid4()),
|
||||||
SystemVariableKey.APP_ID: workflow.app_id,
|
)
|
||||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
|
||||||
# Randomly generated.
|
|
||||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: str(uuid.uuid4()),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Only add chatflow-specific variables for non-workflow types
|
# Only add chatflow-specific variables for non-workflow types
|
||||||
if workflow.type != WorkflowType.WORKFLOW.value:
|
if workflow.type != WorkflowType.WORKFLOW.value:
|
||||||
system_inputs.update(
|
system_variable.query = query
|
||||||
{
|
system_variable.conversation_id = conversation_id
|
||||||
SystemVariableKey.QUERY: query,
|
system_variable.dialogue_count = 0
|
||||||
SystemVariableKey.CONVERSATION_ID: conversation_id,
|
|
||||||
SystemVariableKey.DIALOGUE_COUNT: 0,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
system_inputs = {}
|
system_variable = SystemVariable.empty()
|
||||||
|
|
||||||
# init variable pool
|
# init variable pool
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables=system_inputs,
|
system_variables=system_variable,
|
||||||
user_inputs=user_inputs,
|
user_inputs=user_inputs,
|
||||||
environment_variables=workflow.environment_variables,
|
environment_variables=workflow.environment_variables,
|
||||||
conversation_variables=conversation_variables,
|
# Based on the definition of `VariableUnion`,
|
||||||
|
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||||
|
conversation_variables=cast(list[VariableUnion], conversation_variables), #
|
||||||
)
|
)
|
||||||
|
|
||||||
return variable_pool
|
return variable_pool
|
||||||
|
|||||||
@ -9,12 +9,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
|||||||
from core.workflow.entities.node_entities import NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||||
from core.workflow.enums import SystemVariableKey
|
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.nodes.code.code_node import CodeNode
|
from core.workflow.nodes.code.code_node import CodeNode
|
||||||
from core.workflow.nodes.code.entities import CodeNodeData
|
from core.workflow.nodes.code.entities import CodeNodeData
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
|
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
|
||||||
@ -50,7 +50,7 @@ def init_code_node(code_config: dict):
|
|||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
|
system_variables=SystemVariable(user_id="aaa", files=[]),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
conversation_variables=[],
|
conversation_variables=[],
|
||||||
|
|||||||
@ -6,11 +6,11 @@ import pytest
|
|||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariableKey
|
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.nodes.http_request.node import HttpRequestNode
|
from core.workflow.nodes.http_request.node import HttpRequestNode
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock
|
from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock
|
||||||
@ -44,7 +44,7 @@ def init_http_node(config: dict):
|
|||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
|
system_variables=SystemVariable(user_id="aaa", files=[]),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
conversation_variables=[],
|
conversation_variables=[],
|
||||||
|
|||||||
@ -13,12 +13,12 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
|||||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage
|
from core.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||||
from core.workflow.enums import SystemVariableKey
|
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.nodes.event import RunCompletedEvent
|
from core.workflow.nodes.event import RunCompletedEvent
|
||||||
from core.workflow.nodes.llm.node import LLMNode
|
from core.workflow.nodes.llm.node import LLMNode
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
@ -62,12 +62,14 @@ def init_llm_node(config: dict) -> LLMNode:
|
|||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={
|
system_variables=SystemVariable(
|
||||||
SystemVariableKey.QUERY: "what's the weather today?",
|
user_id="aaa",
|
||||||
SystemVariableKey.FILES: [],
|
app_id=app_id,
|
||||||
SystemVariableKey.CONVERSATION_ID: "abababa",
|
workflow_id=workflow_id,
|
||||||
SystemVariableKey.USER_ID: "aaa",
|
files=[],
|
||||||
},
|
query="what's the weather today?",
|
||||||
|
conversation_id="abababa",
|
||||||
|
),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
conversation_variables=[],
|
conversation_variables=[],
|
||||||
|
|||||||
@ -8,11 +8,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
|||||||
from core.model_runtime.entities import AssistantPromptMessage
|
from core.model_runtime.entities import AssistantPromptMessage
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||||
from core.workflow.enums import SystemVariableKey
|
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
|
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
|
||||||
@ -64,12 +64,9 @@ def init_parameter_extractor_node(config: dict):
|
|||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={
|
system_variables=SystemVariable(
|
||||||
SystemVariableKey.QUERY: "what's the weather in SF",
|
user_id="aaa", files=[], query="what's the weather in SF", conversation_id="abababa"
|
||||||
SystemVariableKey.FILES: [],
|
),
|
||||||
SystemVariableKey.CONVERSATION_ID: "abababa",
|
|
||||||
SystemVariableKey.USER_ID: "aaa",
|
|
||||||
},
|
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
conversation_variables=[],
|
conversation_variables=[],
|
||||||
|
|||||||
@ -6,11 +6,11 @@ import pytest
|
|||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||||
from core.workflow.enums import SystemVariableKey
|
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
|
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
|
||||||
@ -61,7 +61,7 @@ def test_execute_code(setup_code_executor_mock):
|
|||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
|
system_variables=SystemVariable(user_id="aaa", files=[]),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
conversation_variables=[],
|
conversation_variables=[],
|
||||||
|
|||||||
@ -6,12 +6,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
|||||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||||
from core.workflow.enums import SystemVariableKey
|
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.nodes.event.event import RunCompletedEvent
|
from core.workflow.nodes.event.event import RunCompletedEvent
|
||||||
from core.workflow.nodes.tool.tool_node import ToolNode
|
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
|
|
||||||
@ -44,7 +44,7 @@ def init_tool_node(config: dict):
|
|||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
|
system_variables=SystemVariable(user_id="aaa", files=[]),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
conversation_variables=[],
|
conversation_variables=[],
|
||||||
|
|||||||
@ -1,14 +1,49 @@
|
|||||||
|
import dataclasses
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from core.file import File, FileTransferMethod, FileType
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from core.variables import SecretVariable, StringVariable
|
from core.variables.segments import (
|
||||||
|
ArrayAnySegment,
|
||||||
|
ArrayFileSegment,
|
||||||
|
ArrayNumberSegment,
|
||||||
|
ArrayObjectSegment,
|
||||||
|
ArrayStringSegment,
|
||||||
|
FileSegment,
|
||||||
|
FloatSegment,
|
||||||
|
IntegerSegment,
|
||||||
|
NoneSegment,
|
||||||
|
ObjectSegment,
|
||||||
|
Segment,
|
||||||
|
SegmentUnion,
|
||||||
|
StringSegment,
|
||||||
|
get_segment_discriminator,
|
||||||
|
)
|
||||||
|
from core.variables.types import SegmentType
|
||||||
|
from core.variables.variables import (
|
||||||
|
ArrayAnyVariable,
|
||||||
|
ArrayFileVariable,
|
||||||
|
ArrayNumberVariable,
|
||||||
|
ArrayObjectVariable,
|
||||||
|
ArrayStringVariable,
|
||||||
|
FileVariable,
|
||||||
|
FloatVariable,
|
||||||
|
IntegerVariable,
|
||||||
|
NoneVariable,
|
||||||
|
ObjectVariable,
|
||||||
|
SecretVariable,
|
||||||
|
StringVariable,
|
||||||
|
Variable,
|
||||||
|
VariableUnion,
|
||||||
|
)
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.system_variable import SystemVariable
|
||||||
|
|
||||||
|
|
||||||
def test_segment_group_to_text():
|
def test_segment_group_to_text():
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={
|
system_variables=SystemVariable(user_id="fake-user-id"),
|
||||||
SystemVariableKey("user_id"): "fake-user-id",
|
|
||||||
},
|
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[
|
environment_variables=[
|
||||||
SecretVariable(name="secret_key", value="fake-secret-key"),
|
SecretVariable(name="secret_key", value="fake-secret-key"),
|
||||||
@ -30,7 +65,7 @@ def test_segment_group_to_text():
|
|||||||
|
|
||||||
def test_convert_constant_to_segment_group():
|
def test_convert_constant_to_segment_group():
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={},
|
system_variables=SystemVariable(user_id="1", app_id="1", workflow_id="1"),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
conversation_variables=[],
|
conversation_variables=[],
|
||||||
@ -43,9 +78,7 @@ def test_convert_constant_to_segment_group():
|
|||||||
|
|
||||||
def test_convert_variable_to_segment_group():
|
def test_convert_variable_to_segment_group():
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={
|
system_variables=SystemVariable(user_id="fake-user-id"),
|
||||||
SystemVariableKey("user_id"): "fake-user-id",
|
|
||||||
},
|
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
conversation_variables=[],
|
conversation_variables=[],
|
||||||
@ -56,3 +89,297 @@ def test_convert_variable_to_segment_group():
|
|||||||
assert segments_group.log == "fake-user-id"
|
assert segments_group.log == "fake-user-id"
|
||||||
assert isinstance(segments_group.value[0], StringVariable)
|
assert isinstance(segments_group.value[0], StringVariable)
|
||||||
assert segments_group.value[0].value == "fake-user-id"
|
assert segments_group.value[0].value == "fake-user-id"
|
||||||
|
|
||||||
|
|
||||||
|
class _Segments(BaseModel):
|
||||||
|
segments: list[SegmentUnion]
|
||||||
|
|
||||||
|
|
||||||
|
class _Variables(BaseModel):
|
||||||
|
variables: list[VariableUnion]
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_file(
|
||||||
|
file_type: FileType = FileType.DOCUMENT,
|
||||||
|
transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE,
|
||||||
|
filename: str = "test.txt",
|
||||||
|
extension: str = ".txt",
|
||||||
|
mime_type: str = "text/plain",
|
||||||
|
size: int = 1024,
|
||||||
|
) -> File:
|
||||||
|
"""Factory function to create File objects for testing"""
|
||||||
|
return File(
|
||||||
|
tenant_id="test-tenant",
|
||||||
|
type=file_type,
|
||||||
|
transfer_method=transfer_method,
|
||||||
|
filename=filename,
|
||||||
|
extension=extension,
|
||||||
|
mime_type=mime_type,
|
||||||
|
size=size,
|
||||||
|
related_id="test-file-id" if transfer_method != FileTransferMethod.REMOTE_URL else None,
|
||||||
|
remote_url="https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None,
|
||||||
|
storage_key="test-storage-key",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSegmentDumpAndLoad:
|
||||||
|
"""Test suite for segment and variable serialization/deserialization"""
|
||||||
|
|
||||||
|
def test_segments(self):
|
||||||
|
"""Test basic segment serialization compatibility"""
|
||||||
|
model = _Segments(segments=[IntegerSegment(value=1), StringSegment(value="a")])
|
||||||
|
json = model.model_dump_json()
|
||||||
|
print("Json: ", json)
|
||||||
|
loaded = _Segments.model_validate_json(json)
|
||||||
|
assert loaded == model
|
||||||
|
|
||||||
|
def test_segment_number(self):
|
||||||
|
"""Test number segment serialization compatibility"""
|
||||||
|
model = _Segments(segments=[IntegerSegment(value=1), FloatSegment(value=1.0)])
|
||||||
|
json = model.model_dump_json()
|
||||||
|
print("Json: ", json)
|
||||||
|
loaded = _Segments.model_validate_json(json)
|
||||||
|
assert loaded == model
|
||||||
|
|
||||||
|
def test_variables(self):
|
||||||
|
"""Test variable serialization compatibility"""
|
||||||
|
model = _Variables(variables=[IntegerVariable(value=1, name="int"), StringVariable(value="a", name="str")])
|
||||||
|
json = model.model_dump_json()
|
||||||
|
print("Json: ", json)
|
||||||
|
restored = _Variables.model_validate_json(json)
|
||||||
|
assert restored == model
|
||||||
|
|
||||||
|
def test_all_segments_serialization(self):
|
||||||
|
"""Test serialization/deserialization of all segment types"""
|
||||||
|
# Create one instance of each segment type
|
||||||
|
test_file = create_test_file()
|
||||||
|
|
||||||
|
all_segments: list[SegmentUnion] = [
|
||||||
|
NoneSegment(),
|
||||||
|
StringSegment(value="test string"),
|
||||||
|
IntegerSegment(value=42),
|
||||||
|
FloatSegment(value=3.14),
|
||||||
|
ObjectSegment(value={"key": "value", "number": 123}),
|
||||||
|
FileSegment(value=test_file),
|
||||||
|
ArrayAnySegment(value=[1, "string", 3.14, {"key": "value"}]),
|
||||||
|
ArrayStringSegment(value=["hello", "world"]),
|
||||||
|
ArrayNumberSegment(value=[1, 2.5, 3]),
|
||||||
|
ArrayObjectSegment(value=[{"id": 1}, {"id": 2}]),
|
||||||
|
ArrayFileSegment(value=[]), # Empty array to avoid file complexity
|
||||||
|
]
|
||||||
|
|
||||||
|
# Test serialization and deserialization
|
||||||
|
model = _Segments(segments=all_segments)
|
||||||
|
json_str = model.model_dump_json()
|
||||||
|
loaded = _Segments.model_validate_json(json_str)
|
||||||
|
|
||||||
|
# Verify all segments are preserved
|
||||||
|
assert len(loaded.segments) == len(all_segments)
|
||||||
|
|
||||||
|
for original, loaded_segment in zip(all_segments, loaded.segments):
|
||||||
|
assert type(loaded_segment) == type(original)
|
||||||
|
assert loaded_segment.value_type == original.value_type
|
||||||
|
|
||||||
|
# For file segments, compare key properties instead of exact equality
|
||||||
|
if isinstance(original, FileSegment) and isinstance(loaded_segment, FileSegment):
|
||||||
|
orig_file = original.value
|
||||||
|
loaded_file = loaded_segment.value
|
||||||
|
assert isinstance(orig_file, File)
|
||||||
|
assert isinstance(loaded_file, File)
|
||||||
|
assert loaded_file.tenant_id == orig_file.tenant_id
|
||||||
|
assert loaded_file.type == orig_file.type
|
||||||
|
assert loaded_file.filename == orig_file.filename
|
||||||
|
else:
|
||||||
|
assert loaded_segment.value == original.value
|
||||||
|
|
||||||
|
def test_all_variables_serialization(self):
|
||||||
|
"""Test serialization/deserialization of all variable types"""
|
||||||
|
# Create one instance of each variable type
|
||||||
|
test_file = create_test_file()
|
||||||
|
|
||||||
|
all_variables: list[VariableUnion] = [
|
||||||
|
NoneVariable(name="none_var"),
|
||||||
|
StringVariable(value="test string", name="string_var"),
|
||||||
|
IntegerVariable(value=42, name="int_var"),
|
||||||
|
FloatVariable(value=3.14, name="float_var"),
|
||||||
|
ObjectVariable(value={"key": "value", "number": 123}, name="object_var"),
|
||||||
|
FileVariable(value=test_file, name="file_var"),
|
||||||
|
ArrayAnyVariable(value=[1, "string", 3.14, {"key": "value"}], name="array_any_var"),
|
||||||
|
ArrayStringVariable(value=["hello", "world"], name="array_string_var"),
|
||||||
|
ArrayNumberVariable(value=[1, 2.5, 3], name="array_number_var"),
|
||||||
|
ArrayObjectVariable(value=[{"id": 1}, {"id": 2}], name="array_object_var"),
|
||||||
|
ArrayFileVariable(value=[], name="array_file_var"), # Empty array to avoid file complexity
|
||||||
|
]
|
||||||
|
|
||||||
|
# Test serialization and deserialization
|
||||||
|
model = _Variables(variables=all_variables)
|
||||||
|
json_str = model.model_dump_json()
|
||||||
|
loaded = _Variables.model_validate_json(json_str)
|
||||||
|
|
||||||
|
# Verify all variables are preserved
|
||||||
|
assert len(loaded.variables) == len(all_variables)
|
||||||
|
|
||||||
|
for original, loaded_variable in zip(all_variables, loaded.variables):
|
||||||
|
assert type(loaded_variable) == type(original)
|
||||||
|
assert loaded_variable.value_type == original.value_type
|
||||||
|
assert loaded_variable.name == original.name
|
||||||
|
|
||||||
|
# For file variables, compare key properties instead of exact equality
|
||||||
|
if isinstance(original, FileVariable) and isinstance(loaded_variable, FileVariable):
|
||||||
|
orig_file = original.value
|
||||||
|
loaded_file = loaded_variable.value
|
||||||
|
assert isinstance(orig_file, File)
|
||||||
|
assert isinstance(loaded_file, File)
|
||||||
|
assert loaded_file.tenant_id == orig_file.tenant_id
|
||||||
|
assert loaded_file.type == orig_file.type
|
||||||
|
assert loaded_file.filename == orig_file.filename
|
||||||
|
else:
|
||||||
|
assert loaded_variable.value == original.value
|
||||||
|
|
||||||
|
def test_segment_discriminator_function_for_segment_types(self):
|
||||||
|
"""Test the segment discriminator function"""
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class TestCase:
|
||||||
|
segment: Segment
|
||||||
|
expected_segment_type: SegmentType
|
||||||
|
|
||||||
|
file1 = create_test_file()
|
||||||
|
file2 = create_test_file(filename="test2.txt")
|
||||||
|
|
||||||
|
cases = [
|
||||||
|
TestCase(
|
||||||
|
NoneSegment(),
|
||||||
|
SegmentType.NONE,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
StringSegment(value=""),
|
||||||
|
SegmentType.STRING,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
FloatSegment(value=0.0),
|
||||||
|
SegmentType.FLOAT,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
IntegerSegment(value=0),
|
||||||
|
SegmentType.INTEGER,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
ObjectSegment(value={}),
|
||||||
|
SegmentType.OBJECT,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
FileSegment(value=file1),
|
||||||
|
SegmentType.FILE,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
ArrayAnySegment(value=[0, 0.0, ""]),
|
||||||
|
SegmentType.ARRAY_ANY,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
ArrayStringSegment(value=[""]),
|
||||||
|
SegmentType.ARRAY_STRING,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
ArrayNumberSegment(value=[0, 0.0]),
|
||||||
|
SegmentType.ARRAY_NUMBER,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
ArrayObjectSegment(value=[{}]),
|
||||||
|
SegmentType.ARRAY_OBJECT,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
ArrayFileSegment(value=[file1, file2]),
|
||||||
|
SegmentType.ARRAY_FILE,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
for test_case in cases:
|
||||||
|
segment = test_case.segment
|
||||||
|
assert get_segment_discriminator(segment) == test_case.expected_segment_type, (
|
||||||
|
f"get_segment_discriminator failed for type {type(segment)}"
|
||||||
|
)
|
||||||
|
model_dict = segment.model_dump(mode="json")
|
||||||
|
assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, (
|
||||||
|
f"get_segment_discriminator failed for serialized form of type {type(segment)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_variable_discriminator_function_for_variable_types(self):
|
||||||
|
"""Test the variable discriminator function"""
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class TestCase:
|
||||||
|
variable: Variable
|
||||||
|
expected_segment_type: SegmentType
|
||||||
|
|
||||||
|
file1 = create_test_file()
|
||||||
|
file2 = create_test_file(filename="test2.txt")
|
||||||
|
|
||||||
|
cases = [
|
||||||
|
TestCase(
|
||||||
|
NoneVariable(name="none_var"),
|
||||||
|
SegmentType.NONE,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
StringVariable(value="test", name="string_var"),
|
||||||
|
SegmentType.STRING,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
FloatVariable(value=0.0, name="float_var"),
|
||||||
|
SegmentType.FLOAT,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
IntegerVariable(value=0, name="int_var"),
|
||||||
|
SegmentType.INTEGER,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
ObjectVariable(value={}, name="object_var"),
|
||||||
|
SegmentType.OBJECT,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
FileVariable(value=file1, name="file_var"),
|
||||||
|
SegmentType.FILE,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
SecretVariable(value="secret", name="secret_var"),
|
||||||
|
SegmentType.SECRET,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
ArrayAnyVariable(value=[0, 0.0, ""], name="array_any_var"),
|
||||||
|
SegmentType.ARRAY_ANY,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
ArrayStringVariable(value=[""], name="array_string_var"),
|
||||||
|
SegmentType.ARRAY_STRING,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
ArrayNumberVariable(value=[0, 0.0], name="array_number_var"),
|
||||||
|
SegmentType.ARRAY_NUMBER,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
ArrayObjectVariable(value=[{}], name="array_object_var"),
|
||||||
|
SegmentType.ARRAY_OBJECT,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
ArrayFileVariable(value=[file1, file2], name="array_file_var"),
|
||||||
|
SegmentType.ARRAY_FILE,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
for test_case in cases:
|
||||||
|
variable = test_case.variable
|
||||||
|
assert get_segment_discriminator(variable) == test_case.expected_segment_type, (
|
||||||
|
f"get_segment_discriminator failed for type {type(variable)}"
|
||||||
|
)
|
||||||
|
model_dict = variable.model_dump(mode="json")
|
||||||
|
assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, (
|
||||||
|
f"get_segment_discriminator failed for serialized form of type {type(variable)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_invlaid_value_for_discriminator(self):
|
||||||
|
# Test invalid cases
|
||||||
|
assert get_segment_discriminator({"value_type": "invalid"}) is None
|
||||||
|
assert get_segment_discriminator({}) is None
|
||||||
|
assert get_segment_discriminator("not_a_dict") is None
|
||||||
|
assert get_segment_discriminator(42) is None
|
||||||
|
assert get_segment_discriminator(object) is None
|
||||||
|
|||||||
60
api/tests/unit_tests/core/variables/test_segment_type.py
Normal file
60
api/tests/unit_tests/core/variables/test_segment_type.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
from core.variables.types import SegmentType
|
||||||
|
|
||||||
|
|
||||||
|
class TestSegmentTypeIsArrayType:
|
||||||
|
"""
|
||||||
|
Test class for SegmentType.is_array_type method.
|
||||||
|
|
||||||
|
Provides comprehensive coverage of all SegmentType values to ensure
|
||||||
|
correct identification of array and non-array types.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_is_array_type(self):
|
||||||
|
"""
|
||||||
|
Test that all SegmentType enum values are covered in our test cases.
|
||||||
|
|
||||||
|
Ensures comprehensive coverage by verifying that every SegmentType
|
||||||
|
value is tested for the is_array_type method.
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
all_segment_types = set(SegmentType)
|
||||||
|
expected_array_types = [
|
||||||
|
SegmentType.ARRAY_ANY,
|
||||||
|
SegmentType.ARRAY_STRING,
|
||||||
|
SegmentType.ARRAY_NUMBER,
|
||||||
|
SegmentType.ARRAY_OBJECT,
|
||||||
|
SegmentType.ARRAY_FILE,
|
||||||
|
]
|
||||||
|
expected_non_array_types = [
|
||||||
|
SegmentType.INTEGER,
|
||||||
|
SegmentType.FLOAT,
|
||||||
|
SegmentType.NUMBER,
|
||||||
|
SegmentType.STRING,
|
||||||
|
SegmentType.OBJECT,
|
||||||
|
SegmentType.SECRET,
|
||||||
|
SegmentType.FILE,
|
||||||
|
SegmentType.NONE,
|
||||||
|
SegmentType.GROUP,
|
||||||
|
]
|
||||||
|
|
||||||
|
for seg_type in expected_array_types:
|
||||||
|
assert seg_type.is_array_type()
|
||||||
|
|
||||||
|
for seg_type in expected_non_array_types:
|
||||||
|
assert not seg_type.is_array_type()
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
covered_types = set(expected_array_types) | set(expected_non_array_types)
|
||||||
|
assert covered_types == set(SegmentType), "All SegmentType values should be covered in tests"
|
||||||
|
|
||||||
|
def test_all_enum_values_are_supported(self):
|
||||||
|
"""
|
||||||
|
Test that all enum values are supported and return boolean values.
|
||||||
|
|
||||||
|
Validates that every SegmentType enum value can be processed by
|
||||||
|
is_array_type method and returns a boolean value.
|
||||||
|
"""
|
||||||
|
enum_values: list[SegmentType] = list(SegmentType)
|
||||||
|
for seg_type in enum_values:
|
||||||
|
is_array = seg_type.is_array_type()
|
||||||
|
assert isinstance(is_array, bool), f"is_array_type does not return a boolean for segment type {seg_type}"
|
||||||
@ -11,6 +11,7 @@ from core.variables import (
|
|||||||
SegmentType,
|
SegmentType,
|
||||||
StringVariable,
|
StringVariable,
|
||||||
)
|
)
|
||||||
|
from core.variables.variables import Variable
|
||||||
|
|
||||||
|
|
||||||
def test_frozen_variables():
|
def test_frozen_variables():
|
||||||
@ -75,7 +76,7 @@ def test_object_variable_to_object():
|
|||||||
|
|
||||||
|
|
||||||
def test_variable_to_object():
|
def test_variable_to_object():
|
||||||
var = StringVariable(name="text", value="text")
|
var: Variable = StringVariable(name="text", value="text")
|
||||||
assert var.to_object() == "text"
|
assert var.to_object() == "text"
|
||||||
var = IntegerVariable(name="integer", value=42)
|
var = IntegerVariable(name="integer", value=42)
|
||||||
assert var.to_object() == 42
|
assert var.to_object() == 42
|
||||||
|
|||||||
@ -0,0 +1,146 @@
|
|||||||
|
import time
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
|
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_graph_runtime_state() -> GraphRuntimeState:
|
||||||
|
"""Factory function to create a GraphRuntimeState with non-empty values for testing."""
|
||||||
|
# Create a variable pool with system variables
|
||||||
|
system_vars = SystemVariable(
|
||||||
|
user_id="test_user_123",
|
||||||
|
app_id="test_app_456",
|
||||||
|
workflow_id="test_workflow_789",
|
||||||
|
workflow_execution_id="test_execution_001",
|
||||||
|
query="test query",
|
||||||
|
conversation_id="test_conv_123",
|
||||||
|
dialogue_count=5,
|
||||||
|
)
|
||||||
|
variable_pool = VariablePool(system_variables=system_vars)
|
||||||
|
|
||||||
|
# Add some variables to the variable pool
|
||||||
|
variable_pool.add(["test_node", "test_var"], "test_value")
|
||||||
|
variable_pool.add(["another_node", "another_var"], 42)
|
||||||
|
|
||||||
|
# Create LLM usage with realistic values
|
||||||
|
llm_usage = LLMUsage(
|
||||||
|
prompt_tokens=150,
|
||||||
|
prompt_unit_price=Decimal("0.001"),
|
||||||
|
prompt_price_unit=Decimal(1000),
|
||||||
|
prompt_price=Decimal("0.15"),
|
||||||
|
completion_tokens=75,
|
||||||
|
completion_unit_price=Decimal("0.002"),
|
||||||
|
completion_price_unit=Decimal(1000),
|
||||||
|
completion_price=Decimal("0.15"),
|
||||||
|
total_tokens=225,
|
||||||
|
total_price=Decimal("0.30"),
|
||||||
|
currency="USD",
|
||||||
|
latency=1.25,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create runtime route state with some node states
|
||||||
|
node_run_state = RuntimeRouteState()
|
||||||
|
node_state = node_run_state.create_node_state("test_node_1")
|
||||||
|
node_run_state.add_route(node_state.id, "target_node_id")
|
||||||
|
|
||||||
|
return GraphRuntimeState(
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
start_at=time.perf_counter(),
|
||||||
|
total_tokens=100,
|
||||||
|
llm_usage=llm_usage,
|
||||||
|
outputs={
|
||||||
|
"string_output": "test result",
|
||||||
|
"int_output": 42,
|
||||||
|
"float_output": 3.14,
|
||||||
|
"list_output": ["item1", "item2", "item3"],
|
||||||
|
"dict_output": {"key1": "value1", "key2": 123},
|
||||||
|
"nested_dict": {"level1": {"level2": ["nested", "list", 456]}},
|
||||||
|
},
|
||||||
|
node_run_steps=5,
|
||||||
|
node_run_state=node_run_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_basic_round_trip_serialization():
|
||||||
|
"""Test basic round-trip serialization ensures GraphRuntimeState values remain unchanged."""
|
||||||
|
# Create a state with non-empty values
|
||||||
|
original_state = create_test_graph_runtime_state()
|
||||||
|
|
||||||
|
# Serialize to JSON and deserialize back
|
||||||
|
json_data = original_state.model_dump_json()
|
||||||
|
deserialized_state = GraphRuntimeState.model_validate_json(json_data)
|
||||||
|
|
||||||
|
# Core test: ensure the round-trip preserves all values
|
||||||
|
assert deserialized_state == original_state
|
||||||
|
|
||||||
|
# Serialize to JSON and deserialize back
|
||||||
|
dict_data = original_state.model_dump(mode="python")
|
||||||
|
deserialized_state = GraphRuntimeState.model_validate(dict_data)
|
||||||
|
assert deserialized_state == original_state
|
||||||
|
|
||||||
|
# Serialize to JSON and deserialize back
|
||||||
|
dict_data = original_state.model_dump(mode="json")
|
||||||
|
deserialized_state = GraphRuntimeState.model_validate(dict_data)
|
||||||
|
assert deserialized_state == original_state
|
||||||
|
|
||||||
|
|
||||||
|
def test_outputs_field_round_trip():
|
||||||
|
"""Test the problematic outputs field maintains values through round-trip serialization."""
|
||||||
|
original_state = create_test_graph_runtime_state()
|
||||||
|
|
||||||
|
# Serialize and deserialize
|
||||||
|
json_data = original_state.model_dump_json()
|
||||||
|
deserialized_state = GraphRuntimeState.model_validate_json(json_data)
|
||||||
|
|
||||||
|
# Verify the outputs field specifically maintains its values
|
||||||
|
assert deserialized_state.outputs == original_state.outputs
|
||||||
|
assert deserialized_state == original_state
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_outputs_round_trip():
|
||||||
|
"""Test round-trip serialization with empty outputs field."""
|
||||||
|
variable_pool = VariablePool.empty()
|
||||||
|
original_state = GraphRuntimeState(
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
start_at=time.perf_counter(),
|
||||||
|
outputs={}, # Empty outputs
|
||||||
|
)
|
||||||
|
|
||||||
|
json_data = original_state.model_dump_json()
|
||||||
|
deserialized_state = GraphRuntimeState.model_validate_json(json_data)
|
||||||
|
|
||||||
|
assert deserialized_state == original_state
|
||||||
|
|
||||||
|
|
||||||
|
def test_llm_usage_round_trip():
|
||||||
|
# Create LLM usage with specific decimal values
|
||||||
|
llm_usage = LLMUsage(
|
||||||
|
prompt_tokens=100,
|
||||||
|
prompt_unit_price=Decimal("0.0015"),
|
||||||
|
prompt_price_unit=Decimal(1000),
|
||||||
|
prompt_price=Decimal("0.15"),
|
||||||
|
completion_tokens=50,
|
||||||
|
completion_unit_price=Decimal("0.003"),
|
||||||
|
completion_price_unit=Decimal(1000),
|
||||||
|
completion_price=Decimal("0.15"),
|
||||||
|
total_tokens=150,
|
||||||
|
total_price=Decimal("0.30"),
|
||||||
|
currency="USD",
|
||||||
|
latency=2.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
json_data = llm_usage.model_dump_json()
|
||||||
|
deserialized = LLMUsage.model_validate_json(json_data)
|
||||||
|
assert deserialized == llm_usage
|
||||||
|
|
||||||
|
dict_data = llm_usage.model_dump(mode="python")
|
||||||
|
deserialized = LLMUsage.model_validate(dict_data)
|
||||||
|
assert deserialized == llm_usage
|
||||||
|
|
||||||
|
dict_data = llm_usage.model_dump(mode="json")
|
||||||
|
deserialized = LLMUsage.model_validate(dict_data)
|
||||||
|
assert deserialized == llm_usage
|
||||||
@ -0,0 +1,401 @@
|
|||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||||
|
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState
|
||||||
|
|
||||||
|
_TEST_DATETIME = datetime(2024, 1, 15, 10, 30, 45)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRouteNodeStateSerialization:
|
||||||
|
"""Test cases for RouteNodeState Pydantic serialization/deserialization."""
|
||||||
|
|
||||||
|
def _test_route_node_state(self):
|
||||||
|
"""Test comprehensive RouteNodeState serialization with all core fields validation."""
|
||||||
|
|
||||||
|
node_run_result = NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
|
inputs={"input_key": "input_value"},
|
||||||
|
outputs={"output_key": "output_value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
node_state = RouteNodeState(
|
||||||
|
node_id="comprehensive_test_node",
|
||||||
|
start_at=_TEST_DATETIME,
|
||||||
|
finished_at=_TEST_DATETIME,
|
||||||
|
status=RouteNodeState.Status.SUCCESS,
|
||||||
|
node_run_result=node_run_result,
|
||||||
|
index=5,
|
||||||
|
paused_at=_TEST_DATETIME,
|
||||||
|
paused_by="user_123",
|
||||||
|
failed_reason="test_reason",
|
||||||
|
)
|
||||||
|
return node_state
|
||||||
|
|
||||||
|
def test_route_node_state_comprehensive_field_validation(self):
|
||||||
|
"""Test comprehensive RouteNodeState serialization with all core fields validation."""
|
||||||
|
node_state = self._test_route_node_state()
|
||||||
|
serialized = node_state.model_dump()
|
||||||
|
|
||||||
|
# Comprehensive validation of all RouteNodeState fields
|
||||||
|
assert serialized["node_id"] == "comprehensive_test_node"
|
||||||
|
assert serialized["status"] == RouteNodeState.Status.SUCCESS
|
||||||
|
assert serialized["start_at"] == _TEST_DATETIME
|
||||||
|
assert serialized["finished_at"] == _TEST_DATETIME
|
||||||
|
assert serialized["paused_at"] == _TEST_DATETIME
|
||||||
|
assert serialized["paused_by"] == "user_123"
|
||||||
|
assert serialized["failed_reason"] == "test_reason"
|
||||||
|
assert serialized["index"] == 5
|
||||||
|
assert "id" in serialized
|
||||||
|
assert isinstance(serialized["id"], str)
|
||||||
|
uuid.UUID(serialized["id"]) # Validate UUID format
|
||||||
|
|
||||||
|
# Validate nested NodeRunResult structure
|
||||||
|
assert serialized["node_run_result"] is not None
|
||||||
|
assert serialized["node_run_result"]["status"] == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||||
|
assert serialized["node_run_result"]["inputs"] == {"input_key": "input_value"}
|
||||||
|
assert serialized["node_run_result"]["outputs"] == {"output_key": "output_value"}
|
||||||
|
|
||||||
|
def test_route_node_state_minimal_required_fields(self):
|
||||||
|
"""Test RouteNodeState with only required fields, focusing on defaults."""
|
||||||
|
node_state = RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME)
|
||||||
|
|
||||||
|
serialized = node_state.model_dump()
|
||||||
|
|
||||||
|
# Focus on required fields and default values (not re-testing all fields)
|
||||||
|
assert serialized["node_id"] == "minimal_node"
|
||||||
|
assert serialized["start_at"] == _TEST_DATETIME
|
||||||
|
assert serialized["status"] == RouteNodeState.Status.RUNNING # Default status
|
||||||
|
assert serialized["index"] == 1 # Default index
|
||||||
|
assert serialized["node_run_result"] is None # Default None
|
||||||
|
json = node_state.model_dump_json()
|
||||||
|
deserialized = RouteNodeState.model_validate_json(json)
|
||||||
|
assert deserialized == node_state
|
||||||
|
|
||||||
|
def test_route_node_state_deserialization_from_dict(self):
|
||||||
|
"""Test RouteNodeState deserialization from dictionary data."""
|
||||||
|
test_datetime = datetime(2024, 1, 15, 10, 30, 45)
|
||||||
|
test_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
dict_data = {
|
||||||
|
"id": test_id,
|
||||||
|
"node_id": "deserialized_node",
|
||||||
|
"start_at": test_datetime,
|
||||||
|
"status": "success",
|
||||||
|
"finished_at": test_datetime,
|
||||||
|
"index": 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
node_state = RouteNodeState.model_validate(dict_data)
|
||||||
|
|
||||||
|
# Focus on deserialization accuracy
|
||||||
|
assert node_state.id == test_id
|
||||||
|
assert node_state.node_id == "deserialized_node"
|
||||||
|
assert node_state.start_at == test_datetime
|
||||||
|
assert node_state.status == RouteNodeState.Status.SUCCESS
|
||||||
|
assert node_state.finished_at == test_datetime
|
||||||
|
assert node_state.index == 3
|
||||||
|
|
||||||
|
def test_route_node_state_round_trip_consistency(self):
|
||||||
|
node_states = (
|
||||||
|
self._test_route_node_state(),
|
||||||
|
RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME),
|
||||||
|
)
|
||||||
|
for node_state in node_states:
|
||||||
|
json = node_state.model_dump_json()
|
||||||
|
deserialized = RouteNodeState.model_validate_json(json)
|
||||||
|
assert deserialized == node_state
|
||||||
|
|
||||||
|
dict_ = node_state.model_dump(mode="python")
|
||||||
|
deserialized = RouteNodeState.model_validate(dict_)
|
||||||
|
assert deserialized == node_state
|
||||||
|
|
||||||
|
dict_ = node_state.model_dump(mode="json")
|
||||||
|
deserialized = RouteNodeState.model_validate(dict_)
|
||||||
|
assert deserialized == node_state
|
||||||
|
|
||||||
|
|
||||||
|
class TestRouteNodeStateEnumSerialization:
|
||||||
|
"""Dedicated tests for RouteNodeState Status enum serialization behavior."""
|
||||||
|
|
||||||
|
def test_status_enum_model_dump_behavior(self):
|
||||||
|
"""Test Status enum serialization in model_dump() returns enum objects."""
|
||||||
|
|
||||||
|
for status_enum in RouteNodeState.Status:
|
||||||
|
node_state = RouteNodeState(node_id="enum_test", start_at=_TEST_DATETIME, status=status_enum)
|
||||||
|
serialized = node_state.model_dump(mode="python")
|
||||||
|
assert serialized["status"] == status_enum
|
||||||
|
serialized = node_state.model_dump(mode="json")
|
||||||
|
assert serialized["status"] == status_enum.value
|
||||||
|
|
||||||
|
def test_status_enum_json_serialization_behavior(self):
|
||||||
|
"""Test Status enum serialization in JSON returns string values."""
|
||||||
|
test_datetime = datetime(2024, 1, 15, 10, 30, 45)
|
||||||
|
|
||||||
|
enum_to_string_mapping = {
|
||||||
|
RouteNodeState.Status.RUNNING: "running",
|
||||||
|
RouteNodeState.Status.SUCCESS: "success",
|
||||||
|
RouteNodeState.Status.FAILED: "failed",
|
||||||
|
RouteNodeState.Status.PAUSED: "paused",
|
||||||
|
RouteNodeState.Status.EXCEPTION: "exception",
|
||||||
|
}
|
||||||
|
|
||||||
|
for status_enum, expected_string in enum_to_string_mapping.items():
|
||||||
|
node_state = RouteNodeState(node_id="json_enum_test", start_at=test_datetime, status=status_enum)
|
||||||
|
|
||||||
|
json_data = json.loads(node_state.model_dump_json())
|
||||||
|
assert json_data["status"] == expected_string
|
||||||
|
|
||||||
|
def test_status_enum_deserialization_from_string(self):
|
||||||
|
"""Test Status enum deserialization from string values."""
|
||||||
|
test_datetime = datetime(2024, 1, 15, 10, 30, 45)
|
||||||
|
|
||||||
|
string_to_enum_mapping = {
|
||||||
|
"running": RouteNodeState.Status.RUNNING,
|
||||||
|
"success": RouteNodeState.Status.SUCCESS,
|
||||||
|
"failed": RouteNodeState.Status.FAILED,
|
||||||
|
"paused": RouteNodeState.Status.PAUSED,
|
||||||
|
"exception": RouteNodeState.Status.EXCEPTION,
|
||||||
|
}
|
||||||
|
|
||||||
|
for status_string, expected_enum in string_to_enum_mapping.items():
|
||||||
|
dict_data = {
|
||||||
|
"node_id": "enum_deserialize_test",
|
||||||
|
"start_at": test_datetime,
|
||||||
|
"status": status_string,
|
||||||
|
}
|
||||||
|
|
||||||
|
node_state = RouteNodeState.model_validate(dict_data)
|
||||||
|
assert node_state.status == expected_enum
|
||||||
|
|
||||||
|
|
||||||
|
class TestRuntimeRouteStateSerialization:
|
||||||
|
"""Test cases for RuntimeRouteState Pydantic serialization/deserialization."""
|
||||||
|
|
||||||
|
_NODE1_ID = "node_1"
|
||||||
|
_ROUTE_STATE1_ID = str(uuid.uuid4())
|
||||||
|
_NODE2_ID = "node_2"
|
||||||
|
_ROUTE_STATE2_ID = str(uuid.uuid4())
|
||||||
|
_NODE3_ID = "node_3"
|
||||||
|
_ROUTE_STATE3_ID = str(uuid.uuid4())
|
||||||
|
|
||||||
|
def _get_runtime_route_state(self):
|
||||||
|
# Create node states with different configurations
|
||||||
|
node_state_1 = RouteNodeState(
|
||||||
|
id=self._ROUTE_STATE1_ID,
|
||||||
|
node_id=self._NODE1_ID,
|
||||||
|
start_at=_TEST_DATETIME,
|
||||||
|
index=1,
|
||||||
|
)
|
||||||
|
node_state_2 = RouteNodeState(
|
||||||
|
id=self._ROUTE_STATE2_ID,
|
||||||
|
node_id=self._NODE2_ID,
|
||||||
|
start_at=_TEST_DATETIME,
|
||||||
|
status=RouteNodeState.Status.SUCCESS,
|
||||||
|
finished_at=_TEST_DATETIME,
|
||||||
|
index=2,
|
||||||
|
)
|
||||||
|
node_state_3 = RouteNodeState(
|
||||||
|
id=self._ROUTE_STATE3_ID,
|
||||||
|
node_id=self._NODE3_ID,
|
||||||
|
start_at=_TEST_DATETIME,
|
||||||
|
status=RouteNodeState.Status.FAILED,
|
||||||
|
failed_reason="Test failure",
|
||||||
|
index=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
runtime_state = RuntimeRouteState(
|
||||||
|
routes={node_state_1.id: [node_state_2.id, node_state_3.id], node_state_2.id: [node_state_3.id]},
|
||||||
|
node_state_mapping={
|
||||||
|
node_state_1.id: node_state_1,
|
||||||
|
node_state_2.id: node_state_2,
|
||||||
|
node_state_3.id: node_state_3,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return runtime_state
|
||||||
|
|
||||||
|
def test_runtime_route_state_comprehensive_structure_validation(self):
|
||||||
|
"""Test comprehensive RuntimeRouteState serialization with full structure validation."""
|
||||||
|
|
||||||
|
runtime_state = self._get_runtime_route_state()
|
||||||
|
serialized = runtime_state.model_dump()
|
||||||
|
|
||||||
|
# Comprehensive validation of RuntimeRouteState structure
|
||||||
|
assert "routes" in serialized
|
||||||
|
assert "node_state_mapping" in serialized
|
||||||
|
assert isinstance(serialized["routes"], dict)
|
||||||
|
assert isinstance(serialized["node_state_mapping"], dict)
|
||||||
|
|
||||||
|
# Validate routes dictionary structure and content
|
||||||
|
assert len(serialized["routes"]) == 2
|
||||||
|
assert self._ROUTE_STATE1_ID in serialized["routes"]
|
||||||
|
assert self._ROUTE_STATE2_ID in serialized["routes"]
|
||||||
|
assert serialized["routes"][self._ROUTE_STATE1_ID] == [self._ROUTE_STATE2_ID, self._ROUTE_STATE3_ID]
|
||||||
|
assert serialized["routes"][self._ROUTE_STATE2_ID] == [self._ROUTE_STATE3_ID]
|
||||||
|
|
||||||
|
# Validate node_state_mapping dictionary structure and content
|
||||||
|
assert len(serialized["node_state_mapping"]) == 3
|
||||||
|
for state_id in [
|
||||||
|
self._ROUTE_STATE1_ID,
|
||||||
|
self._ROUTE_STATE2_ID,
|
||||||
|
self._ROUTE_STATE3_ID,
|
||||||
|
]:
|
||||||
|
assert state_id in serialized["node_state_mapping"]
|
||||||
|
node_data = serialized["node_state_mapping"][state_id]
|
||||||
|
node_state = runtime_state.node_state_mapping[state_id]
|
||||||
|
assert node_data["node_id"] == node_state.node_id
|
||||||
|
assert node_data["status"] == node_state.status
|
||||||
|
assert node_data["index"] == node_state.index
|
||||||
|
|
||||||
|
def test_runtime_route_state_empty_collections(self):
|
||||||
|
"""Test RuntimeRouteState with empty collections, focusing on default behavior."""
|
||||||
|
runtime_state = RuntimeRouteState()
|
||||||
|
serialized = runtime_state.model_dump()
|
||||||
|
|
||||||
|
# Focus on default empty collection behavior
|
||||||
|
assert serialized["routes"] == {}
|
||||||
|
assert serialized["node_state_mapping"] == {}
|
||||||
|
assert isinstance(serialized["routes"], dict)
|
||||||
|
assert isinstance(serialized["node_state_mapping"], dict)
|
||||||
|
|
||||||
|
def test_runtime_route_state_json_serialization_structure(self):
|
||||||
|
"""Test RuntimeRouteState JSON serialization structure."""
|
||||||
|
node_state = RouteNodeState(node_id="json_node", start_at=_TEST_DATETIME)
|
||||||
|
|
||||||
|
runtime_state = RuntimeRouteState(
|
||||||
|
routes={"source": ["target1", "target2"]}, node_state_mapping={node_state.id: node_state}
|
||||||
|
)
|
||||||
|
|
||||||
|
json_str = runtime_state.model_dump_json()
|
||||||
|
json_data = json.loads(json_str)
|
||||||
|
|
||||||
|
# Focus on JSON structure validation
|
||||||
|
assert isinstance(json_str, str)
|
||||||
|
assert isinstance(json_data, dict)
|
||||||
|
assert "routes" in json_data
|
||||||
|
assert "node_state_mapping" in json_data
|
||||||
|
assert json_data["routes"]["source"] == ["target1", "target2"]
|
||||||
|
assert node_state.id in json_data["node_state_mapping"]
|
||||||
|
|
||||||
|
def test_runtime_route_state_deserialization_from_dict(self):
|
||||||
|
"""Test RuntimeRouteState deserialization from dictionary data."""
|
||||||
|
node_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
dict_data = {
|
||||||
|
"routes": {"source_node": ["target_node_1", "target_node_2"]},
|
||||||
|
"node_state_mapping": {
|
||||||
|
node_id: {
|
||||||
|
"id": node_id,
|
||||||
|
"node_id": "test_node",
|
||||||
|
"start_at": _TEST_DATETIME,
|
||||||
|
"status": "running",
|
||||||
|
"index": 1,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
runtime_state = RuntimeRouteState.model_validate(dict_data)
|
||||||
|
|
||||||
|
# Focus on deserialization accuracy
|
||||||
|
assert runtime_state.routes == {"source_node": ["target_node_1", "target_node_2"]}
|
||||||
|
assert len(runtime_state.node_state_mapping) == 1
|
||||||
|
assert node_id in runtime_state.node_state_mapping
|
||||||
|
|
||||||
|
deserialized_node = runtime_state.node_state_mapping[node_id]
|
||||||
|
assert deserialized_node.node_id == "test_node"
|
||||||
|
assert deserialized_node.status == RouteNodeState.Status.RUNNING
|
||||||
|
assert deserialized_node.index == 1
|
||||||
|
|
||||||
|
def test_runtime_route_state_round_trip_consistency(self):
|
||||||
|
"""Test RuntimeRouteState round-trip serialization consistency."""
|
||||||
|
original = self._get_runtime_route_state()
|
||||||
|
|
||||||
|
# Dictionary round trip
|
||||||
|
dict_data = original.model_dump(mode="python")
|
||||||
|
reconstructed = RuntimeRouteState.model_validate(dict_data)
|
||||||
|
assert reconstructed == original
|
||||||
|
|
||||||
|
dict_data = original.model_dump(mode="json")
|
||||||
|
reconstructed = RuntimeRouteState.model_validate(dict_data)
|
||||||
|
assert reconstructed == original
|
||||||
|
|
||||||
|
# JSON round trip
|
||||||
|
json_str = original.model_dump_json()
|
||||||
|
json_reconstructed = RuntimeRouteState.model_validate_json(json_str)
|
||||||
|
assert json_reconstructed == original
|
||||||
|
|
||||||
|
|
||||||
|
class TestSerializationEdgeCases:
|
||||||
|
"""Test edge cases and error conditions for serialization/deserialization."""
|
||||||
|
|
||||||
|
def test_invalid_status_deserialization(self):
|
||||||
|
"""Test deserialization with invalid status values."""
|
||||||
|
test_datetime = _TEST_DATETIME
|
||||||
|
invalid_data = {
|
||||||
|
"node_id": "invalid_test",
|
||||||
|
"start_at": test_datetime,
|
||||||
|
"status": "invalid_status",
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
RouteNodeState.model_validate(invalid_data)
|
||||||
|
assert "status" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_missing_required_fields_deserialization(self):
|
||||||
|
"""Test deserialization with missing required fields."""
|
||||||
|
incomplete_data = {"id": str(uuid.uuid4())}
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
RouteNodeState.model_validate(incomplete_data)
|
||||||
|
error_str = str(exc_info.value)
|
||||||
|
assert "node_id" in error_str or "start_at" in error_str
|
||||||
|
|
||||||
|
def test_invalid_datetime_deserialization(self):
|
||||||
|
"""Test deserialization with invalid datetime values."""
|
||||||
|
invalid_data = {
|
||||||
|
"node_id": "datetime_test",
|
||||||
|
"start_at": "invalid_datetime",
|
||||||
|
"status": "running",
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
RouteNodeState.model_validate(invalid_data)
|
||||||
|
assert "start_at" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_invalid_routes_structure_deserialization(self):
|
||||||
|
"""Test RuntimeRouteState deserialization with invalid routes structure."""
|
||||||
|
invalid_data = {
|
||||||
|
"routes": "invalid_routes_structure", # Should be dict
|
||||||
|
"node_state_mapping": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
RuntimeRouteState.model_validate(invalid_data)
|
||||||
|
assert "routes" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_timezone_handling_in_datetime_fields(self):
|
||||||
|
"""Test timezone handling in datetime field serialization."""
|
||||||
|
utc_datetime = datetime.now(UTC)
|
||||||
|
naive_datetime = utc_datetime.replace(tzinfo=None)
|
||||||
|
|
||||||
|
node_state = RouteNodeState(node_id="timezone_test", start_at=naive_datetime)
|
||||||
|
dict_ = node_state.model_dump()
|
||||||
|
|
||||||
|
assert dict_["start_at"] == naive_datetime
|
||||||
|
|
||||||
|
# Test round trip
|
||||||
|
reconstructed = RouteNodeState.model_validate(dict_)
|
||||||
|
assert reconstructed.start_at == naive_datetime
|
||||||
|
assert reconstructed.start_at.tzinfo is None
|
||||||
|
|
||||||
|
json = node_state.model_dump_json()
|
||||||
|
|
||||||
|
reconstructed = RouteNodeState.model_validate_json(json)
|
||||||
|
assert reconstructed.start_at == naive_datetime
|
||||||
|
assert reconstructed.start_at.tzinfo is None
|
||||||
@ -8,7 +8,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
|||||||
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
|
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||||
from core.workflow.enums import SystemVariableKey
|
|
||||||
from core.workflow.graph_engine.entities.event import (
|
from core.workflow.graph_engine.entities.event import (
|
||||||
BaseNodeEvent,
|
BaseNodeEvent,
|
||||||
GraphRunFailedEvent,
|
GraphRunFailedEvent,
|
||||||
@ -27,6 +26,7 @@ from core.workflow.nodes.code.code_node import CodeNode
|
|||||||
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
|
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
|
||||||
from core.workflow.nodes.llm.node import LLMNode
|
from core.workflow.nodes.llm.node import LLMNode
|
||||||
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
|
|
||||||
@ -171,7 +171,8 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
|
|||||||
graph = Graph.init(graph_config=graph_config)
|
graph = Graph.init(graph_config=graph_config)
|
||||||
|
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
|
system_variables=SystemVariable(user_id="aaa", app_id="1", workflow_id="1", files=[]),
|
||||||
|
user_inputs={"query": "hi"},
|
||||||
)
|
)
|
||||||
|
|
||||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||||
@ -293,12 +294,12 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
|
|||||||
graph = Graph.init(graph_config=graph_config)
|
graph = Graph.init(graph_config=graph_config)
|
||||||
|
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={
|
system_variables=SystemVariable(
|
||||||
SystemVariableKey.QUERY: "what's the weather in SF",
|
user_id="aaa",
|
||||||
SystemVariableKey.FILES: [],
|
files=[],
|
||||||
SystemVariableKey.CONVERSATION_ID: "abababa",
|
query="what's the weather in SF",
|
||||||
SystemVariableKey.USER_ID: "aaa",
|
conversation_id="abababa",
|
||||||
},
|
),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -474,12 +475,12 @@ def test_run_branch(mock_close, mock_remove):
|
|||||||
graph = Graph.init(graph_config=graph_config)
|
graph = Graph.init(graph_config=graph_config)
|
||||||
|
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={
|
system_variables=SystemVariable(
|
||||||
SystemVariableKey.QUERY: "hi",
|
user_id="aaa",
|
||||||
SystemVariableKey.FILES: [],
|
files=[],
|
||||||
SystemVariableKey.CONVERSATION_ID: "abababa",
|
query="hi",
|
||||||
SystemVariableKey.USER_ID: "aaa",
|
conversation_id="abababa",
|
||||||
},
|
),
|
||||||
user_inputs={"uid": "takato"},
|
user_inputs={"uid": "takato"},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -804,18 +805,22 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
|
|||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
pool = VariablePool(
|
pool = VariablePool(
|
||||||
system_variables={
|
system_variables=SystemVariable(
|
||||||
SystemVariableKey.QUERY: "dify",
|
user_id="1",
|
||||||
SystemVariableKey.FILES: [],
|
files=[],
|
||||||
SystemVariableKey.CONVERSATION_ID: "abababa",
|
query="dify",
|
||||||
SystemVariableKey.USER_ID: "1",
|
conversation_id="abababa",
|
||||||
},
|
),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
)
|
)
|
||||||
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
|
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
|
system_variables=SystemVariable(
|
||||||
|
user_id="aaa",
|
||||||
|
files=[],
|
||||||
|
),
|
||||||
|
user_inputs={"query": "hi"},
|
||||||
)
|
)
|
||||||
|
|
||||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||||
|
|||||||
@ -5,11 +5,11 @@ from unittest.mock import MagicMock
|
|||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||||
from core.workflow.enums import SystemVariableKey
|
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
@ -51,7 +51,7 @@ def test_execute_answer():
|
|||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
pool = VariablePool(
|
pool = VariablePool(
|
||||||
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
|
system_variables=SystemVariable(user_id="aaa", files=[]),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -3,7 +3,6 @@ from collections.abc import Generator
|
|||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariableKey
|
|
||||||
from core.workflow.graph_engine.entities.event import (
|
from core.workflow.graph_engine.entities.event import (
|
||||||
GraphEngineEvent,
|
GraphEngineEvent,
|
||||||
NodeRunStartedEvent,
|
NodeRunStartedEvent,
|
||||||
@ -15,6 +14,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta
|
|||||||
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||||
from core.workflow.nodes.enums import NodeType
|
from core.workflow.nodes.enums import NodeType
|
||||||
from core.workflow.nodes.start.entities import StartNodeData
|
from core.workflow.nodes.start.entities import StartNodeData
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
|
|
||||||
|
|
||||||
def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
|
def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
|
||||||
@ -180,12 +180,12 @@ def test_process():
|
|||||||
graph = Graph.init(graph_config=graph_config)
|
graph = Graph.init(graph_config=graph_config)
|
||||||
|
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={
|
system_variables=SystemVariable(
|
||||||
SystemVariableKey.QUERY: "what's the weather in SF",
|
user_id="aaa",
|
||||||
SystemVariableKey.FILES: [],
|
files=[],
|
||||||
SystemVariableKey.CONVERSATION_ID: "abababa",
|
query="what's the weather in SF",
|
||||||
SystemVariableKey.USER_ID: "aaa",
|
conversation_id="abababa",
|
||||||
},
|
),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -7,12 +7,13 @@ from core.workflow.nodes.http_request import (
|
|||||||
)
|
)
|
||||||
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
|
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
|
||||||
from core.workflow.nodes.http_request.executor import Executor
|
from core.workflow.nodes.http_request.executor import Executor
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
|
|
||||||
|
|
||||||
def test_executor_with_json_body_and_number_variable():
|
def test_executor_with_json_body_and_number_variable():
|
||||||
# Prepare the variable pool
|
# Prepare the variable pool
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={},
|
system_variables=SystemVariable.empty(),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
)
|
)
|
||||||
variable_pool.add(["pre_node_id", "number"], 42)
|
variable_pool.add(["pre_node_id", "number"], 42)
|
||||||
@ -65,7 +66,7 @@ def test_executor_with_json_body_and_number_variable():
|
|||||||
def test_executor_with_json_body_and_object_variable():
|
def test_executor_with_json_body_and_object_variable():
|
||||||
# Prepare the variable pool
|
# Prepare the variable pool
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={},
|
system_variables=SystemVariable.empty(),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
)
|
)
|
||||||
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
|
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
|
||||||
@ -120,7 +121,7 @@ def test_executor_with_json_body_and_object_variable():
|
|||||||
def test_executor_with_json_body_and_nested_object_variable():
|
def test_executor_with_json_body_and_nested_object_variable():
|
||||||
# Prepare the variable pool
|
# Prepare the variable pool
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={},
|
system_variables=SystemVariable.empty(),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
)
|
)
|
||||||
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
|
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
|
||||||
@ -174,7 +175,7 @@ def test_executor_with_json_body_and_nested_object_variable():
|
|||||||
|
|
||||||
|
|
||||||
def test_extract_selectors_from_template_with_newline():
|
def test_extract_selectors_from_template_with_newline():
|
||||||
variable_pool = VariablePool()
|
variable_pool = VariablePool(system_variables=SystemVariable.empty())
|
||||||
variable_pool.add(("node_id", "custom_query"), "line1\nline2")
|
variable_pool.add(("node_id", "custom_query"), "line1\nline2")
|
||||||
node_data = HttpRequestNodeData(
|
node_data = HttpRequestNodeData(
|
||||||
title="Test JSON Body with Nested Object Variable",
|
title="Test JSON Body with Nested Object Variable",
|
||||||
@ -201,7 +202,7 @@ def test_extract_selectors_from_template_with_newline():
|
|||||||
def test_executor_with_form_data():
|
def test_executor_with_form_data():
|
||||||
# Prepare the variable pool
|
# Prepare the variable pool
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={},
|
system_variables=SystemVariable.empty(),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
)
|
)
|
||||||
variable_pool.add(["pre_node_id", "text_field"], "Hello, World!")
|
variable_pool.add(["pre_node_id", "text_field"], "Hello, World!")
|
||||||
@ -280,7 +281,11 @@ def test_init_headers():
|
|||||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||||
)
|
)
|
||||||
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
|
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
|
||||||
return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool())
|
return Executor(
|
||||||
|
node_data=node_data,
|
||||||
|
timeout=timeout,
|
||||||
|
variable_pool=VariablePool(system_variables=SystemVariable.empty()),
|
||||||
|
)
|
||||||
|
|
||||||
executor = create_executor("aa\n cc:")
|
executor = create_executor("aa\n cc:")
|
||||||
executor._init_headers()
|
executor._init_headers()
|
||||||
@ -310,7 +315,11 @@ def test_init_params():
|
|||||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||||
)
|
)
|
||||||
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
|
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
|
||||||
return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool())
|
return Executor(
|
||||||
|
node_data=node_data,
|
||||||
|
timeout=timeout,
|
||||||
|
variable_pool=VariablePool(system_variables=SystemVariable.empty()),
|
||||||
|
)
|
||||||
|
|
||||||
# Test basic key-value pairs
|
# Test basic key-value pairs
|
||||||
executor = create_executor("key1:value1\nkey2:value2")
|
executor = create_executor("key1:value1\nkey2:value2")
|
||||||
|
|||||||
@ -15,6 +15,7 @@ from core.workflow.nodes.http_request import (
|
|||||||
HttpRequestNodeBody,
|
HttpRequestNodeBody,
|
||||||
HttpRequestNodeData,
|
HttpRequestNodeData,
|
||||||
)
|
)
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
|
|
||||||
@ -40,7 +41,7 @@ def test_http_request_node_binary_file(monkeypatch):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={},
|
system_variables=SystemVariable.empty(),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
)
|
)
|
||||||
variable_pool.add(
|
variable_pool.add(
|
||||||
@ -128,7 +129,7 @@ def test_http_request_node_form_with_file(monkeypatch):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={},
|
system_variables=SystemVariable.empty(),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
)
|
)
|
||||||
variable_pool.add(
|
variable_pool.add(
|
||||||
@ -223,7 +224,7 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
|
|||||||
)
|
)
|
||||||
|
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={},
|
system_variables=SystemVariable.empty(),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -7,7 +7,6 @@ from core.variables.segments import ArrayAnySegment, ArrayStringSegment
|
|||||||
from core.workflow.entities.node_entities import NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||||
from core.workflow.enums import SystemVariableKey
|
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
@ -15,6 +14,7 @@ from core.workflow.nodes.event import RunCompletedEvent
|
|||||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode
|
from core.workflow.nodes.iteration.entities import ErrorHandleMode
|
||||||
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
||||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
|
|
||||||
@ -151,12 +151,12 @@ def test_run():
|
|||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
pool = VariablePool(
|
pool = VariablePool(
|
||||||
system_variables={
|
system_variables=SystemVariable(
|
||||||
SystemVariableKey.QUERY: "dify",
|
user_id="1",
|
||||||
SystemVariableKey.FILES: [],
|
files=[],
|
||||||
SystemVariableKey.CONVERSATION_ID: "abababa",
|
query="dify",
|
||||||
SystemVariableKey.USER_ID: "1",
|
conversation_id="abababa",
|
||||||
},
|
),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
)
|
)
|
||||||
@ -368,12 +368,12 @@ def test_run_parallel():
|
|||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
pool = VariablePool(
|
pool = VariablePool(
|
||||||
system_variables={
|
system_variables=SystemVariable(
|
||||||
SystemVariableKey.QUERY: "dify",
|
user_id="1",
|
||||||
SystemVariableKey.FILES: [],
|
files=[],
|
||||||
SystemVariableKey.CONVERSATION_ID: "abababa",
|
query="dify",
|
||||||
SystemVariableKey.USER_ID: "1",
|
conversation_id="abababa",
|
||||||
},
|
),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
)
|
)
|
||||||
@ -584,12 +584,12 @@ def test_iteration_run_in_parallel_mode():
|
|||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
pool = VariablePool(
|
pool = VariablePool(
|
||||||
system_variables={
|
system_variables=SystemVariable(
|
||||||
SystemVariableKey.QUERY: "dify",
|
user_id="1",
|
||||||
SystemVariableKey.FILES: [],
|
files=[],
|
||||||
SystemVariableKey.CONVERSATION_ID: "abababa",
|
query="dify",
|
||||||
SystemVariableKey.USER_ID: "1",
|
conversation_id="abababa",
|
||||||
},
|
),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
)
|
)
|
||||||
@ -808,12 +808,12 @@ def test_iteration_run_error_handle():
|
|||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
pool = VariablePool(
|
pool = VariablePool(
|
||||||
system_variables={
|
system_variables=SystemVariable(
|
||||||
SystemVariableKey.QUERY: "dify",
|
user_id="1",
|
||||||
SystemVariableKey.FILES: [],
|
files=[],
|
||||||
SystemVariableKey.CONVERSATION_ID: "abababa",
|
query="dify",
|
||||||
SystemVariableKey.USER_ID: "1",
|
conversation_id="abababa",
|
||||||
},
|
),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -36,6 +36,7 @@ from core.workflow.nodes.llm.entities import (
|
|||||||
)
|
)
|
||||||
from core.workflow.nodes.llm.file_saver import LLMFileSaver
|
from core.workflow.nodes.llm.file_saver import LLMFileSaver
|
||||||
from core.workflow.nodes.llm.node import LLMNode
|
from core.workflow.nodes.llm.node import LLMNode
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.provider import ProviderType
|
from models.provider import ProviderType
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
@ -104,7 +105,7 @@ def graph() -> Graph:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def graph_runtime_state() -> GraphRuntimeState:
|
def graph_runtime_state() -> GraphRuntimeState:
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={},
|
system_variables=SystemVariable.empty(),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
)
|
)
|
||||||
return GraphRuntimeState(
|
return GraphRuntimeState(
|
||||||
@ -181,7 +182,7 @@ def test_fetch_files_with_file_segment():
|
|||||||
related_id="1",
|
related_id="1",
|
||||||
storage_key="",
|
storage_key="",
|
||||||
)
|
)
|
||||||
variable_pool = VariablePool()
|
variable_pool = VariablePool.empty()
|
||||||
variable_pool.add(["sys", "files"], file)
|
variable_pool.add(["sys", "files"], file)
|
||||||
|
|
||||||
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
|
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
|
||||||
@ -209,7 +210,7 @@ def test_fetch_files_with_array_file_segment():
|
|||||||
storage_key="",
|
storage_key="",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
variable_pool = VariablePool()
|
variable_pool = VariablePool.empty()
|
||||||
variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
|
variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
|
||||||
|
|
||||||
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
|
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
|
||||||
@ -217,7 +218,7 @@ def test_fetch_files_with_array_file_segment():
|
|||||||
|
|
||||||
|
|
||||||
def test_fetch_files_with_none_segment():
|
def test_fetch_files_with_none_segment():
|
||||||
variable_pool = VariablePool()
|
variable_pool = VariablePool.empty()
|
||||||
variable_pool.add(["sys", "files"], NoneSegment())
|
variable_pool.add(["sys", "files"], NoneSegment())
|
||||||
|
|
||||||
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
|
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
|
||||||
@ -225,7 +226,7 @@ def test_fetch_files_with_none_segment():
|
|||||||
|
|
||||||
|
|
||||||
def test_fetch_files_with_array_any_segment():
|
def test_fetch_files_with_array_any_segment():
|
||||||
variable_pool = VariablePool()
|
variable_pool = VariablePool.empty()
|
||||||
variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
|
variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
|
||||||
|
|
||||||
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
|
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
|
||||||
@ -233,7 +234,7 @@ def test_fetch_files_with_array_any_segment():
|
|||||||
|
|
||||||
|
|
||||||
def test_fetch_files_with_non_existent_variable():
|
def test_fetch_files_with_non_existent_variable():
|
||||||
variable_pool = VariablePool()
|
variable_pool = VariablePool.empty()
|
||||||
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
|
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
|
||||||
assert result == []
|
assert result == []
|
||||||
|
|
||||||
|
|||||||
@ -5,11 +5,11 @@ from unittest.mock import MagicMock
|
|||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||||
from core.workflow.enums import SystemVariableKey
|
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
@ -53,7 +53,7 @@ def test_execute_answer():
|
|||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
|
system_variables=SystemVariable(user_id="aaa", files=[]),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
conversation_variables=[],
|
conversation_variables=[],
|
||||||
|
|||||||
@ -5,7 +5,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
|||||||
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
|
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||||
from core.workflow.enums import SystemVariableKey
|
|
||||||
from core.workflow.graph_engine.entities.event import (
|
from core.workflow.graph_engine.entities.event import (
|
||||||
GraphRunPartialSucceededEvent,
|
GraphRunPartialSucceededEvent,
|
||||||
NodeRunExceptionEvent,
|
NodeRunExceptionEvent,
|
||||||
@ -17,6 +16,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntime
|
|||||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||||
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
|
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
|
||||||
from core.workflow.nodes.llm.node import LLMNode
|
from core.workflow.nodes.llm.node import LLMNode
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
|
|
||||||
@ -167,12 +167,12 @@ class ContinueOnErrorTestHelper:
|
|||||||
"""Helper method to create a graph engine instance for testing"""
|
"""Helper method to create a graph engine instance for testing"""
|
||||||
graph = Graph.init(graph_config=graph_config)
|
graph = Graph.init(graph_config=graph_config)
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={
|
system_variables=SystemVariable(
|
||||||
SystemVariableKey.QUERY: "clear",
|
user_id="aaa",
|
||||||
SystemVariableKey.FILES: [],
|
files=[],
|
||||||
SystemVariableKey.CONVERSATION_ID: "abababa",
|
query="clear",
|
||||||
SystemVariableKey.USER_ID: "aaa",
|
conversation_id="abababa",
|
||||||
},
|
),
|
||||||
user_inputs=user_inputs or {"uid": "takato"},
|
user_inputs=user_inputs or {"uid": "takato"},
|
||||||
)
|
)
|
||||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||||
|
|||||||
@ -7,12 +7,12 @@ from core.file import File, FileTransferMethod, FileType
|
|||||||
from core.variables import ArrayFileSegment
|
from core.variables import ArrayFileSegment
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||||
from core.workflow.enums import SystemVariableKey
|
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.nodes.if_else.entities import IfElseNodeData
|
from core.workflow.nodes.if_else.entities import IfElseNodeData
|
||||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition
|
from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
@ -37,9 +37,7 @@ def test_execute_if_else_result_true():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
pool = VariablePool(
|
pool = VariablePool(system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={})
|
||||||
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={}
|
|
||||||
)
|
|
||||||
pool.add(["start", "array_contains"], ["ab", "def"])
|
pool.add(["start", "array_contains"], ["ab", "def"])
|
||||||
pool.add(["start", "array_not_contains"], ["ac", "def"])
|
pool.add(["start", "array_not_contains"], ["ac", "def"])
|
||||||
pool.add(["start", "contains"], "cabcde")
|
pool.add(["start", "contains"], "cabcde")
|
||||||
@ -157,7 +155,7 @@ def test_execute_if_else_result_false():
|
|||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
pool = VariablePool(
|
pool = VariablePool(
|
||||||
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
|
system_variables=SystemVariable(user_id="aaa", files=[]),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -15,6 +15,7 @@ from core.workflow.nodes.enums import ErrorStrategy
|
|||||||
from core.workflow.nodes.event import RunCompletedEvent
|
from core.workflow.nodes.event import RunCompletedEvent
|
||||||
from core.workflow.nodes.tool import ToolNode
|
from core.workflow.nodes.tool import ToolNode
|
||||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from models import UserFrom, WorkflowType
|
from models import UserFrom, WorkflowType
|
||||||
|
|
||||||
|
|
||||||
@ -34,7 +35,7 @@ def _create_tool_node():
|
|||||||
version="1",
|
version="1",
|
||||||
)
|
)
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={},
|
system_variables=SystemVariable.empty(),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
)
|
)
|
||||||
node = ToolNode(
|
node = ToolNode(
|
||||||
|
|||||||
@ -7,12 +7,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
|||||||
from core.variables import ArrayStringVariable, StringVariable
|
from core.variables import ArrayStringVariable, StringVariable
|
||||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariableKey
|
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode
|
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode
|
||||||
from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode
|
from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
|
|
||||||
@ -68,7 +68,7 @@ def test_overwrite_string_variable():
|
|||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
|
system_variables=SystemVariable(conversation_id=conversation_id),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
conversation_variables=[conversation_variable],
|
conversation_variables=[conversation_variable],
|
||||||
@ -165,7 +165,7 @@ def test_append_variable_to_array():
|
|||||||
conversation_id = str(uuid.uuid4())
|
conversation_id = str(uuid.uuid4())
|
||||||
|
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
|
system_variables=SystemVariable(conversation_id=conversation_id),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
conversation_variables=[conversation_variable],
|
conversation_variables=[conversation_variable],
|
||||||
@ -256,7 +256,7 @@ def test_clear_array():
|
|||||||
|
|
||||||
conversation_id = str(uuid.uuid4())
|
conversation_id = str(uuid.uuid4())
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
|
system_variables=SystemVariable(conversation_id=conversation_id),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
conversation_variables=[conversation_variable],
|
conversation_variables=[conversation_variable],
|
||||||
|
|||||||
@ -5,12 +5,12 @@ from uuid import uuid4
|
|||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.variables import ArrayStringVariable
|
from core.variables import ArrayStringVariable
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariableKey
|
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode
|
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode
|
||||||
from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation
|
from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
|
|
||||||
@ -109,7 +109,7 @@ def test_remove_first_from_array():
|
|||||||
)
|
)
|
||||||
|
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
|
system_variables=SystemVariable(conversation_id="conversation_id"),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
conversation_variables=[conversation_variable],
|
conversation_variables=[conversation_variable],
|
||||||
@ -196,7 +196,7 @@ def test_remove_last_from_array():
|
|||||||
)
|
)
|
||||||
|
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
|
system_variables=SystemVariable(conversation_id="conversation_id"),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
conversation_variables=[conversation_variable],
|
conversation_variables=[conversation_variable],
|
||||||
@ -275,7 +275,7 @@ def test_remove_first_from_empty_array():
|
|||||||
)
|
)
|
||||||
|
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
|
system_variables=SystemVariable(conversation_id="conversation_id"),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
conversation_variables=[conversation_variable],
|
conversation_variables=[conversation_variable],
|
||||||
@ -354,7 +354,7 @@ def test_remove_last_from_empty_array():
|
|||||||
)
|
)
|
||||||
|
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
|
system_variables=SystemVariable(conversation_id="conversation_id"),
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
conversation_variables=[conversation_variable],
|
conversation_variables=[conversation_variable],
|
||||||
|
|||||||
251
api/tests/unit_tests/core/workflow/test_system_variable.py
Normal file
251
api/tests/unit_tests/core/workflow/test_system_variable.py
Normal file
@ -0,0 +1,251 @@
|
|||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from core.file.enums import FileTransferMethod, FileType
|
||||||
|
from core.file.models import File
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
|
|
||||||
|
# Test data constants for SystemVariable serialization tests
|
||||||
|
VALID_BASE_DATA: dict[str, Any] = {
|
||||||
|
"user_id": "a20f06b1-8703-45ab-937c-860a60072113",
|
||||||
|
"app_id": "661bed75-458d-49c9-b487-fda0762677b9",
|
||||||
|
"workflow_id": "d31f2136-b292-4ae0-96d4-1e77894a4f43",
|
||||||
|
}
|
||||||
|
|
||||||
|
COMPLETE_VALID_DATA: dict[str, Any] = {
|
||||||
|
**VALID_BASE_DATA,
|
||||||
|
"query": "test query",
|
||||||
|
"files": [],
|
||||||
|
"conversation_id": "91f1eb7d-69f4-4d7b-b82f-4003d51744b9",
|
||||||
|
"dialogue_count": 5,
|
||||||
|
"workflow_run_id": "eb4704b5-2274-47f2-bfcd-0452daa82cb5",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_file() -> File:
|
||||||
|
"""Create a test File object for serialization tests."""
|
||||||
|
return File(
|
||||||
|
tenant_id="test-tenant-id",
|
||||||
|
type=FileType.DOCUMENT,
|
||||||
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
|
related_id="test-file-id",
|
||||||
|
filename="test.txt",
|
||||||
|
extension=".txt",
|
||||||
|
mime_type="text/plain",
|
||||||
|
size=1024,
|
||||||
|
storage_key="test-storage-key",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSystemVariableSerialization:
|
||||||
|
"""Focused tests for SystemVariable serialization/deserialization logic."""
|
||||||
|
|
||||||
|
def test_basic_deserialization(self):
|
||||||
|
"""Test successful deserialization from JSON structure with all fields correctly mapped."""
|
||||||
|
# Test with complete data
|
||||||
|
system_var = SystemVariable(**COMPLETE_VALID_DATA)
|
||||||
|
|
||||||
|
# Verify all fields are correctly mapped
|
||||||
|
assert system_var.user_id == COMPLETE_VALID_DATA["user_id"]
|
||||||
|
assert system_var.app_id == COMPLETE_VALID_DATA["app_id"]
|
||||||
|
assert system_var.workflow_id == COMPLETE_VALID_DATA["workflow_id"]
|
||||||
|
assert system_var.query == COMPLETE_VALID_DATA["query"]
|
||||||
|
assert system_var.conversation_id == COMPLETE_VALID_DATA["conversation_id"]
|
||||||
|
assert system_var.dialogue_count == COMPLETE_VALID_DATA["dialogue_count"]
|
||||||
|
assert system_var.workflow_execution_id == COMPLETE_VALID_DATA["workflow_run_id"]
|
||||||
|
assert system_var.files == []
|
||||||
|
|
||||||
|
# Test with minimal data (only required fields)
|
||||||
|
minimal_var = SystemVariable(**VALID_BASE_DATA)
|
||||||
|
assert minimal_var.user_id == VALID_BASE_DATA["user_id"]
|
||||||
|
assert minimal_var.app_id == VALID_BASE_DATA["app_id"]
|
||||||
|
assert minimal_var.workflow_id == VALID_BASE_DATA["workflow_id"]
|
||||||
|
assert minimal_var.query is None
|
||||||
|
assert minimal_var.conversation_id is None
|
||||||
|
assert minimal_var.dialogue_count is None
|
||||||
|
assert minimal_var.workflow_execution_id is None
|
||||||
|
assert minimal_var.files == []
|
||||||
|
|
||||||
|
def test_alias_handling(self):
|
||||||
|
"""Test workflow_execution_id vs workflow_run_id alias resolution - core deserialization logic."""
|
||||||
|
workflow_id = "eb4704b5-2274-47f2-bfcd-0452daa82cb5"
|
||||||
|
|
||||||
|
# Test workflow_run_id only (preferred alias)
|
||||||
|
data_run_id = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
|
||||||
|
system_var1 = SystemVariable(**data_run_id)
|
||||||
|
assert system_var1.workflow_execution_id == workflow_id
|
||||||
|
|
||||||
|
# Test workflow_execution_id only (direct field name)
|
||||||
|
data_execution_id = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
|
||||||
|
system_var2 = SystemVariable(**data_execution_id)
|
||||||
|
assert system_var2.workflow_execution_id == workflow_id
|
||||||
|
|
||||||
|
# Test both present - workflow_run_id should take precedence
|
||||||
|
data_both = {
|
||||||
|
**VALID_BASE_DATA,
|
||||||
|
"workflow_execution_id": "should-be-ignored",
|
||||||
|
"workflow_run_id": workflow_id,
|
||||||
|
}
|
||||||
|
system_var3 = SystemVariable(**data_both)
|
||||||
|
assert system_var3.workflow_execution_id == workflow_id
|
||||||
|
|
||||||
|
# Test neither present - should be None
|
||||||
|
system_var4 = SystemVariable(**VALID_BASE_DATA)
|
||||||
|
assert system_var4.workflow_execution_id is None
|
||||||
|
|
||||||
|
def test_serialization_round_trip(self):
|
||||||
|
"""Test that serialize → deserialize produces the same result with alias handling."""
|
||||||
|
# Create original SystemVariable
|
||||||
|
original = SystemVariable(**COMPLETE_VALID_DATA)
|
||||||
|
|
||||||
|
# Serialize to dict
|
||||||
|
serialized = original.model_dump(mode="json")
|
||||||
|
|
||||||
|
# Verify alias is used in serialization (workflow_run_id, not workflow_execution_id)
|
||||||
|
assert "workflow_run_id" in serialized
|
||||||
|
assert "workflow_execution_id" not in serialized
|
||||||
|
assert serialized["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
|
||||||
|
|
||||||
|
# Deserialize back
|
||||||
|
deserialized = SystemVariable(**serialized)
|
||||||
|
|
||||||
|
# Verify all fields match after round-trip
|
||||||
|
assert deserialized.user_id == original.user_id
|
||||||
|
assert deserialized.app_id == original.app_id
|
||||||
|
assert deserialized.workflow_id == original.workflow_id
|
||||||
|
assert deserialized.query == original.query
|
||||||
|
assert deserialized.conversation_id == original.conversation_id
|
||||||
|
assert deserialized.dialogue_count == original.dialogue_count
|
||||||
|
assert deserialized.workflow_execution_id == original.workflow_execution_id
|
||||||
|
assert list(deserialized.files) == list(original.files)
|
||||||
|
|
||||||
|
def test_json_round_trip(self):
|
||||||
|
"""Test JSON serialization/deserialization consistency with proper structure."""
|
||||||
|
# Create original SystemVariable
|
||||||
|
original = SystemVariable(**COMPLETE_VALID_DATA)
|
||||||
|
|
||||||
|
# Serialize to JSON string
|
||||||
|
json_str = original.model_dump_json()
|
||||||
|
|
||||||
|
# Parse JSON and verify structure
|
||||||
|
json_data = json.loads(json_str)
|
||||||
|
assert "workflow_run_id" in json_data
|
||||||
|
assert "workflow_execution_id" not in json_data
|
||||||
|
assert json_data["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
|
||||||
|
|
||||||
|
# Deserialize from JSON data
|
||||||
|
deserialized = SystemVariable(**json_data)
|
||||||
|
|
||||||
|
# Verify key fields match after JSON round-trip
|
||||||
|
assert deserialized.workflow_execution_id == original.workflow_execution_id
|
||||||
|
assert deserialized.user_id == original.user_id
|
||||||
|
assert deserialized.app_id == original.app_id
|
||||||
|
assert deserialized.workflow_id == original.workflow_id
|
||||||
|
|
||||||
|
def test_files_field_deserialization(self):
|
||||||
|
"""Test deserialization with File objects in the files field - SystemVariable specific logic."""
|
||||||
|
# Test with empty files list
|
||||||
|
data_empty = {**VALID_BASE_DATA, "files": []}
|
||||||
|
system_var_empty = SystemVariable(**data_empty)
|
||||||
|
assert system_var_empty.files == []
|
||||||
|
|
||||||
|
# Test with single File object
|
||||||
|
test_file = create_test_file()
|
||||||
|
data_single = {**VALID_BASE_DATA, "files": [test_file]}
|
||||||
|
system_var_single = SystemVariable(**data_single)
|
||||||
|
assert len(system_var_single.files) == 1
|
||||||
|
assert system_var_single.files[0].filename == "test.txt"
|
||||||
|
assert system_var_single.files[0].tenant_id == "test-tenant-id"
|
||||||
|
|
||||||
|
# Test with multiple File objects
|
||||||
|
file1 = File(
|
||||||
|
tenant_id="tenant1",
|
||||||
|
type=FileType.DOCUMENT,
|
||||||
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
|
related_id="file1",
|
||||||
|
filename="doc1.txt",
|
||||||
|
storage_key="key1",
|
||||||
|
)
|
||||||
|
file2 = File(
|
||||||
|
tenant_id="tenant2",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||||
|
remote_url="https://example.com/image.jpg",
|
||||||
|
filename="image.jpg",
|
||||||
|
storage_key="key2",
|
||||||
|
)
|
||||||
|
|
||||||
|
data_multiple = {**VALID_BASE_DATA, "files": [file1, file2]}
|
||||||
|
system_var_multiple = SystemVariable(**data_multiple)
|
||||||
|
assert len(system_var_multiple.files) == 2
|
||||||
|
assert system_var_multiple.files[0].filename == "doc1.txt"
|
||||||
|
assert system_var_multiple.files[1].filename == "image.jpg"
|
||||||
|
|
||||||
|
# Verify files field serialization/deserialization
|
||||||
|
serialized = system_var_multiple.model_dump(mode="json")
|
||||||
|
deserialized = SystemVariable(**serialized)
|
||||||
|
assert len(deserialized.files) == 2
|
||||||
|
assert deserialized.files[0].filename == "doc1.txt"
|
||||||
|
assert deserialized.files[1].filename == "image.jpg"
|
||||||
|
|
||||||
|
def test_alias_serialization_consistency(self):
|
||||||
|
"""Test that alias handling works consistently in both serialization directions."""
|
||||||
|
workflow_id = "test-workflow-id"
|
||||||
|
|
||||||
|
# Create with workflow_run_id (alias)
|
||||||
|
data_with_alias = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
|
||||||
|
system_var = SystemVariable(**data_with_alias)
|
||||||
|
|
||||||
|
# Serialize and verify alias is used
|
||||||
|
serialized = system_var.model_dump()
|
||||||
|
assert serialized["workflow_run_id"] == workflow_id
|
||||||
|
assert "workflow_execution_id" not in serialized
|
||||||
|
|
||||||
|
# Deserialize and verify field mapping
|
||||||
|
deserialized = SystemVariable(**serialized)
|
||||||
|
assert deserialized.workflow_execution_id == workflow_id
|
||||||
|
|
||||||
|
# Test JSON serialization path
|
||||||
|
json_serialized = json.loads(system_var.model_dump_json())
|
||||||
|
assert json_serialized["workflow_run_id"] == workflow_id
|
||||||
|
assert "workflow_execution_id" not in json_serialized
|
||||||
|
|
||||||
|
json_deserialized = SystemVariable(**json_serialized)
|
||||||
|
assert json_deserialized.workflow_execution_id == workflow_id
|
||||||
|
|
||||||
|
def test_model_validator_serialization_logic(self):
|
||||||
|
"""Test the custom model validator behavior for serialization scenarios."""
|
||||||
|
workflow_id = "test-workflow-execution-id"
|
||||||
|
|
||||||
|
# Test direct instantiation with workflow_execution_id (should work)
|
||||||
|
data1 = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
|
||||||
|
system_var1 = SystemVariable(**data1)
|
||||||
|
assert system_var1.workflow_execution_id == workflow_id
|
||||||
|
|
||||||
|
# Test serialization of the above (should use alias)
|
||||||
|
serialized1 = system_var1.model_dump()
|
||||||
|
assert "workflow_run_id" in serialized1
|
||||||
|
assert serialized1["workflow_run_id"] == workflow_id
|
||||||
|
|
||||||
|
# Test both present - workflow_run_id takes precedence (validator logic)
|
||||||
|
data2 = {
|
||||||
|
**VALID_BASE_DATA,
|
||||||
|
"workflow_execution_id": "should-be-removed",
|
||||||
|
"workflow_run_id": workflow_id,
|
||||||
|
}
|
||||||
|
system_var2 = SystemVariable(**data2)
|
||||||
|
assert system_var2.workflow_execution_id == workflow_id
|
||||||
|
|
||||||
|
# Verify serialization consistency
|
||||||
|
serialized2 = system_var2.model_dump()
|
||||||
|
assert serialized2["workflow_run_id"] == workflow_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_constructor_with_extra_key():
|
||||||
|
# Test that SystemVariable should forbid extra keys
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
# This should fail because there is an unexpected key.
|
||||||
|
SystemVariable(invalid_key=1) # type: ignore
|
||||||
@ -1,17 +1,43 @@
|
|||||||
|
import uuid
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import ValidationError
|
|
||||||
|
|
||||||
from core.file import File, FileTransferMethod, FileType
|
from core.file import File, FileTransferMethod, FileType
|
||||||
from core.variables import FileSegment, StringSegment
|
from core.variables import FileSegment, StringSegment
|
||||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
|
from core.variables.segments import (
|
||||||
|
ArrayAnySegment,
|
||||||
|
ArrayFileSegment,
|
||||||
|
ArrayNumberSegment,
|
||||||
|
ArrayObjectSegment,
|
||||||
|
ArrayStringSegment,
|
||||||
|
FloatSegment,
|
||||||
|
IntegerSegment,
|
||||||
|
NoneSegment,
|
||||||
|
ObjectSegment,
|
||||||
|
)
|
||||||
|
from core.variables.variables import (
|
||||||
|
ArrayNumberVariable,
|
||||||
|
ArrayObjectVariable,
|
||||||
|
ArrayStringVariable,
|
||||||
|
FloatVariable,
|
||||||
|
IntegerVariable,
|
||||||
|
ObjectVariable,
|
||||||
|
StringVariable,
|
||||||
|
VariableUnion,
|
||||||
|
)
|
||||||
|
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.system_variable import SystemVariable
|
||||||
from factories.variable_factory import build_segment, segment_to_variable
|
from factories.variable_factory import build_segment, segment_to_variable
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def pool():
|
def pool():
|
||||||
return VariablePool(system_variables={}, user_inputs={})
|
return VariablePool(
|
||||||
|
system_variables=SystemVariable(user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"),
|
||||||
|
user_inputs={},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -52,18 +78,28 @@ def test_use_long_selector(pool):
|
|||||||
|
|
||||||
class TestVariablePool:
|
class TestVariablePool:
|
||||||
def test_constructor(self):
|
def test_constructor(self):
|
||||||
pool = VariablePool()
|
# Test with minimal required SystemVariable
|
||||||
|
minimal_system_vars = SystemVariable(
|
||||||
|
user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"
|
||||||
|
)
|
||||||
|
pool = VariablePool(system_variables=minimal_system_vars)
|
||||||
|
|
||||||
|
# Test with all parameters
|
||||||
pool = VariablePool(
|
pool = VariablePool(
|
||||||
variable_dictionary={},
|
variable_dictionary={},
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
system_variables={},
|
system_variables=minimal_system_vars,
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
conversation_variables=[],
|
conversation_variables=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Test with more complex SystemVariable
|
||||||
|
complex_system_vars = SystemVariable(
|
||||||
|
user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"
|
||||||
|
)
|
||||||
pool = VariablePool(
|
pool = VariablePool(
|
||||||
user_inputs={"key": "value"},
|
user_inputs={"key": "value"},
|
||||||
system_variables={SystemVariableKey.WORKFLOW_ID: "test_workflow_id"},
|
system_variables=complex_system_vars,
|
||||||
environment_variables=[
|
environment_variables=[
|
||||||
segment_to_variable(
|
segment_to_variable(
|
||||||
segment=build_segment(1),
|
segment=build_segment(1),
|
||||||
@ -80,6 +116,323 @@ class TestVariablePool:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_constructor_with_invalid_system_variable_key(self):
|
def test_get_system_variables(self):
|
||||||
with pytest.raises(ValidationError):
|
sys_var = SystemVariable(
|
||||||
VariablePool(system_variables={"invalid_key": "value"}) # type: ignore
|
user_id="test_user_id",
|
||||||
|
app_id="test_app_id",
|
||||||
|
workflow_id="test_workflow_id",
|
||||||
|
workflow_execution_id="test_execution_123",
|
||||||
|
query="test query",
|
||||||
|
conversation_id="test_conv_id",
|
||||||
|
dialogue_count=5,
|
||||||
|
)
|
||||||
|
pool = VariablePool(system_variables=sys_var)
|
||||||
|
|
||||||
|
kv = [
|
||||||
|
("user_id", sys_var.user_id),
|
||||||
|
("app_id", sys_var.app_id),
|
||||||
|
("workflow_id", sys_var.workflow_id),
|
||||||
|
("workflow_run_id", sys_var.workflow_execution_id),
|
||||||
|
("query", sys_var.query),
|
||||||
|
("conversation_id", sys_var.conversation_id),
|
||||||
|
("dialogue_count", sys_var.dialogue_count),
|
||||||
|
]
|
||||||
|
for key, expected_value in kv:
|
||||||
|
segment = pool.get([SYSTEM_VARIABLE_NODE_ID, key])
|
||||||
|
assert segment is not None
|
||||||
|
assert segment.value == expected_value
|
||||||
|
|
||||||
|
|
||||||
|
class TestVariablePoolSerialization:
|
||||||
|
"""Test cases for VariablePool serialization and deserialization using Pydantic's built-in methods.
|
||||||
|
|
||||||
|
These tests focus exclusively on serialization/deserialization logic to ensure that
|
||||||
|
VariablePool data can be properly serialized to dictionaries/JSON and reconstructed
|
||||||
|
while preserving all data integrity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_NODE1_ID = "node_1"
|
||||||
|
_NODE2_ID = "node_2"
|
||||||
|
_NODE3_ID = "node_3"
|
||||||
|
|
||||||
|
def _create_pool_without_file(self):
|
||||||
|
# Create comprehensive system variables
|
||||||
|
system_vars = SystemVariable(
|
||||||
|
user_id="test_user_id",
|
||||||
|
app_id="test_app_id",
|
||||||
|
workflow_id="test_workflow_id",
|
||||||
|
workflow_execution_id="test_execution_123",
|
||||||
|
query="test query",
|
||||||
|
conversation_id="test_conv_id",
|
||||||
|
dialogue_count=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create environment variables with all types including ArrayFileVariable
|
||||||
|
env_vars: list[VariableUnion] = [
|
||||||
|
StringVariable(
|
||||||
|
id="env_string_id",
|
||||||
|
name="env_string",
|
||||||
|
value="env_string_value",
|
||||||
|
selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_string"],
|
||||||
|
),
|
||||||
|
IntegerVariable(
|
||||||
|
id="env_integer_id",
|
||||||
|
name="env_integer",
|
||||||
|
value=1,
|
||||||
|
selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_integer"],
|
||||||
|
),
|
||||||
|
FloatVariable(
|
||||||
|
id="env_float_id",
|
||||||
|
name="env_float",
|
||||||
|
value=1.0,
|
||||||
|
selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_float"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create conversation variables with complex data
|
||||||
|
conv_vars: list[VariableUnion] = [
|
||||||
|
StringVariable(
|
||||||
|
id="conv_string_id",
|
||||||
|
name="conv_string",
|
||||||
|
value="conv_string_value",
|
||||||
|
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_string"],
|
||||||
|
),
|
||||||
|
IntegerVariable(
|
||||||
|
id="conv_integer_id",
|
||||||
|
name="conv_integer",
|
||||||
|
value=1,
|
||||||
|
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_integer"],
|
||||||
|
),
|
||||||
|
FloatVariable(
|
||||||
|
id="conv_float_id",
|
||||||
|
name="conv_float",
|
||||||
|
value=1.0,
|
||||||
|
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_float"],
|
||||||
|
),
|
||||||
|
ObjectVariable(
|
||||||
|
id="conv_object_id",
|
||||||
|
name="conv_object",
|
||||||
|
value={"key": "value", "nested": {"data": 123}},
|
||||||
|
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_object"],
|
||||||
|
),
|
||||||
|
ArrayStringVariable(
|
||||||
|
id="conv_array_string_id",
|
||||||
|
name="conv_array_string",
|
||||||
|
value=["conv_array_string_value"],
|
||||||
|
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_string"],
|
||||||
|
),
|
||||||
|
ArrayNumberVariable(
|
||||||
|
id="conv_array_number_id",
|
||||||
|
name="conv_array_number",
|
||||||
|
value=[1, 1.0],
|
||||||
|
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_number"],
|
||||||
|
),
|
||||||
|
ArrayObjectVariable(
|
||||||
|
id="conv_array_object_id",
|
||||||
|
name="conv_array_object",
|
||||||
|
value=[{"a": 1}, {"b": "2"}],
|
||||||
|
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_object"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create comprehensive user inputs
|
||||||
|
user_inputs = {
|
||||||
|
"string_input": "test_value",
|
||||||
|
"number_input": 42,
|
||||||
|
"object_input": {"nested": {"key": "value"}},
|
||||||
|
"array_input": ["item1", "item2", "item3"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create VariablePool
|
||||||
|
pool = VariablePool(
|
||||||
|
system_variables=system_vars,
|
||||||
|
user_inputs=user_inputs,
|
||||||
|
environment_variables=env_vars,
|
||||||
|
conversation_variables=conv_vars,
|
||||||
|
)
|
||||||
|
return pool
|
||||||
|
|
||||||
|
def _add_node_data_to_pool(self, pool: VariablePool, with_file=False):
|
||||||
|
test_file = File(
|
||||||
|
tenant_id="test_tenant_id",
|
||||||
|
type=FileType.DOCUMENT,
|
||||||
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
|
related_id="test_related_id",
|
||||||
|
remote_url="test_url",
|
||||||
|
filename="test_file.txt",
|
||||||
|
storage_key="test_storage_key",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add various segment types to variable dictionary
|
||||||
|
pool.add((self._NODE1_ID, "string_var"), StringSegment(value="test_string"))
|
||||||
|
pool.add((self._NODE1_ID, "int_var"), IntegerSegment(value=123))
|
||||||
|
pool.add((self._NODE1_ID, "float_var"), FloatSegment(value=45.67))
|
||||||
|
pool.add((self._NODE1_ID, "object_var"), ObjectSegment(value={"test": "data"}))
|
||||||
|
if with_file:
|
||||||
|
pool.add((self._NODE1_ID, "file_var"), FileSegment(value=test_file))
|
||||||
|
pool.add((self._NODE1_ID, "none_var"), NoneSegment())
|
||||||
|
|
||||||
|
# Add array segments including ArrayFileVariable
|
||||||
|
pool.add((self._NODE2_ID, "array_string"), ArrayStringSegment(value=["a", "b", "c"]))
|
||||||
|
pool.add((self._NODE2_ID, "array_number"), ArrayNumberSegment(value=[1, 2, 3]))
|
||||||
|
pool.add((self._NODE2_ID, "array_object"), ArrayObjectSegment(value=[{"a": 1}, {"b": 2}]))
|
||||||
|
if with_file:
|
||||||
|
pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file]))
|
||||||
|
pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}]))
|
||||||
|
|
||||||
|
# Add nested variables
|
||||||
|
pool.add((self._NODE3_ID, "nested", "deep", "var"), StringSegment(value="deep_value"))
|
||||||
|
|
||||||
|
def test_system_variables(self):
|
||||||
|
sys_vars = SystemVariable(
|
||||||
|
user_id="test_user_id",
|
||||||
|
app_id="test_app_id",
|
||||||
|
workflow_id="test_workflow_id",
|
||||||
|
workflow_execution_id="test_execution_123",
|
||||||
|
query="test query",
|
||||||
|
conversation_id="test_conv_id",
|
||||||
|
dialogue_count=5,
|
||||||
|
)
|
||||||
|
pool = VariablePool(system_variables=sys_vars)
|
||||||
|
json = pool.model_dump_json()
|
||||||
|
pool2 = VariablePool.model_validate_json(json)
|
||||||
|
assert pool2.system_variables == sys_vars
|
||||||
|
|
||||||
|
for mode in ["json", "python"]:
|
||||||
|
dict_ = pool.model_dump(mode=mode)
|
||||||
|
pool2 = VariablePool.model_validate(dict_)
|
||||||
|
assert pool2.system_variables == sys_vars
|
||||||
|
|
||||||
|
def test_pool_without_file_vars(self):
|
||||||
|
pool = self._create_pool_without_file()
|
||||||
|
json = pool.model_dump_json()
|
||||||
|
pool2 = pool.model_validate_json(json)
|
||||||
|
assert pool2.system_variables == pool.system_variables
|
||||||
|
assert pool2.conversation_variables == pool.conversation_variables
|
||||||
|
assert pool2.environment_variables == pool.environment_variables
|
||||||
|
assert pool2.user_inputs == pool.user_inputs
|
||||||
|
assert pool2.variable_dictionary == pool.variable_dictionary
|
||||||
|
assert pool2 == pool
|
||||||
|
|
||||||
|
def test_basic_dictionary_round_trip(self):
|
||||||
|
"""Test basic round-trip serialization: model_dump() → model_validate()"""
|
||||||
|
# Create a comprehensive VariablePool with all data types
|
||||||
|
original_pool = self._create_pool_without_file()
|
||||||
|
self._add_node_data_to_pool(original_pool)
|
||||||
|
|
||||||
|
# Serialize to dictionary using Pydantic's model_dump()
|
||||||
|
serialized_data = original_pool.model_dump()
|
||||||
|
|
||||||
|
# Verify serialized data structure
|
||||||
|
assert isinstance(serialized_data, dict)
|
||||||
|
assert "system_variables" in serialized_data
|
||||||
|
assert "user_inputs" in serialized_data
|
||||||
|
assert "environment_variables" in serialized_data
|
||||||
|
assert "conversation_variables" in serialized_data
|
||||||
|
assert "variable_dictionary" in serialized_data
|
||||||
|
|
||||||
|
# Deserialize back using Pydantic's model_validate()
|
||||||
|
reconstructed_pool = VariablePool.model_validate(serialized_data)
|
||||||
|
|
||||||
|
# Verify data integrity is preserved
|
||||||
|
self._assert_pools_equal(original_pool, reconstructed_pool)
|
||||||
|
|
||||||
|
def test_json_round_trip(self):
|
||||||
|
"""Test JSON round-trip serialization: model_dump_json() → model_validate_json()"""
|
||||||
|
# Create a comprehensive VariablePool with all data types
|
||||||
|
original_pool = self._create_pool_without_file()
|
||||||
|
self._add_node_data_to_pool(original_pool)
|
||||||
|
|
||||||
|
# Serialize to JSON string using Pydantic's model_dump_json()
|
||||||
|
json_data = original_pool.model_dump_json()
|
||||||
|
|
||||||
|
# Verify JSON is valid string
|
||||||
|
assert isinstance(json_data, str)
|
||||||
|
assert len(json_data) > 0
|
||||||
|
|
||||||
|
# Deserialize back using Pydantic's model_validate_json()
|
||||||
|
reconstructed_pool = VariablePool.model_validate_json(json_data)
|
||||||
|
|
||||||
|
# Verify data integrity is preserved
|
||||||
|
self._assert_pools_equal(original_pool, reconstructed_pool)
|
||||||
|
|
||||||
|
def test_complex_data_serialization(self):
|
||||||
|
"""Test serialization of complex data structures including ArrayFileVariable"""
|
||||||
|
original_pool = self._create_pool_without_file()
|
||||||
|
self._add_node_data_to_pool(original_pool, with_file=True)
|
||||||
|
|
||||||
|
# Test dictionary round-trip
|
||||||
|
dict_data = original_pool.model_dump()
|
||||||
|
reconstructed_dict = VariablePool.model_validate(dict_data)
|
||||||
|
|
||||||
|
# Test JSON round-trip
|
||||||
|
json_data = original_pool.model_dump_json()
|
||||||
|
reconstructed_json = VariablePool.model_validate_json(json_data)
|
||||||
|
|
||||||
|
# Verify both reconstructed pools are equivalent
|
||||||
|
self._assert_pools_equal(reconstructed_dict, reconstructed_json)
|
||||||
|
# TODO: assert the data for file object...
|
||||||
|
|
||||||
|
def _assert_pools_equal(self, pool1: VariablePool, pool2: VariablePool) -> None:
|
||||||
|
"""Assert that two VariablePools contain equivalent data"""
|
||||||
|
|
||||||
|
# Compare system variables
|
||||||
|
assert pool1.system_variables == pool2.system_variables
|
||||||
|
|
||||||
|
# Compare user inputs
|
||||||
|
assert dict(pool1.user_inputs) == dict(pool2.user_inputs)
|
||||||
|
|
||||||
|
# Compare environment variables count
|
||||||
|
assert pool1.environment_variables == pool2.environment_variables
|
||||||
|
|
||||||
|
# Compare conversation variables count
|
||||||
|
assert pool1.conversation_variables == pool2.conversation_variables
|
||||||
|
|
||||||
|
# Test key variable retrievals to ensure functionality is preserved
|
||||||
|
test_selectors = [
|
||||||
|
(SYSTEM_VARIABLE_NODE_ID, "user_id"),
|
||||||
|
(SYSTEM_VARIABLE_NODE_ID, "app_id"),
|
||||||
|
(ENVIRONMENT_VARIABLE_NODE_ID, "env_string"),
|
||||||
|
(ENVIRONMENT_VARIABLE_NODE_ID, "env_number"),
|
||||||
|
(CONVERSATION_VARIABLE_NODE_ID, "conv_string"),
|
||||||
|
(self._NODE1_ID, "string_var"),
|
||||||
|
(self._NODE1_ID, "int_var"),
|
||||||
|
(self._NODE1_ID, "float_var"),
|
||||||
|
(self._NODE2_ID, "array_string"),
|
||||||
|
(self._NODE2_ID, "array_number"),
|
||||||
|
(self._NODE3_ID, "nested", "deep", "var"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for selector in test_selectors:
|
||||||
|
val1 = pool1.get(selector)
|
||||||
|
val2 = pool2.get(selector)
|
||||||
|
|
||||||
|
# Both should exist or both should be None
|
||||||
|
assert (val1 is None) == (val2 is None)
|
||||||
|
|
||||||
|
if val1 is not None and val2 is not None:
|
||||||
|
# Values should be equal
|
||||||
|
assert val1.value == val2.value
|
||||||
|
# Value types should be the same (more important than exact class type)
|
||||||
|
assert val1.value_type == val2.value_type
|
||||||
|
|
||||||
|
def test_variable_pool_deserialization_default_dict(self):
|
||||||
|
variable_pool = VariablePool(
|
||||||
|
user_inputs={"a": 1, "b": "2"},
|
||||||
|
system_variables=SystemVariable(workflow_id=str(uuid.uuid4())),
|
||||||
|
environment_variables=[
|
||||||
|
StringVariable(name="str_var", value="a"),
|
||||||
|
],
|
||||||
|
conversation_variables=[IntegerVariable(name="int_var", value=1)],
|
||||||
|
)
|
||||||
|
assert isinstance(variable_pool.variable_dictionary, defaultdict)
|
||||||
|
json = variable_pool.model_dump_json()
|
||||||
|
loaded = VariablePool.model_validate_json(json)
|
||||||
|
assert isinstance(loaded.variable_dictionary, defaultdict)
|
||||||
|
|
||||||
|
loaded.add(["non_exist_node", "a"], 1)
|
||||||
|
|
||||||
|
pool_dict = variable_pool.model_dump()
|
||||||
|
loaded = VariablePool.model_validate(pool_dict)
|
||||||
|
assert isinstance(loaded.variable_dictionary, defaultdict)
|
||||||
|
loaded.add(["non_exist_node", "a"], 1)
|
||||||
|
|||||||
@ -18,10 +18,10 @@ from core.workflow.entities.workflow_node_execution import (
|
|||||||
WorkflowNodeExecutionMetadataKey,
|
WorkflowNodeExecutionMetadataKey,
|
||||||
WorkflowNodeExecutionStatus,
|
WorkflowNodeExecutionStatus,
|
||||||
)
|
)
|
||||||
from core.workflow.enums import SystemVariableKey
|
|
||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
||||||
from models.enums import CreatorUserRole
|
from models.enums import CreatorUserRole
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
@ -67,14 +67,14 @@ def real_app_generate_entity():
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def real_workflow_system_variables():
|
def real_workflow_system_variables():
|
||||||
return {
|
return SystemVariable(
|
||||||
SystemVariableKey.QUERY: "test query",
|
query="test query",
|
||||||
SystemVariableKey.CONVERSATION_ID: "test-conversation-id",
|
conversation_id="test-conversation-id",
|
||||||
SystemVariableKey.USER_ID: "test-user-id",
|
user_id="test-user-id",
|
||||||
SystemVariableKey.APP_ID: "test-app-id",
|
app_id="test-app-id",
|
||||||
SystemVariableKey.WORKFLOW_ID: "test-workflow-id",
|
workflow_id="test-workflow-id",
|
||||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: "test-workflow-run-id",
|
workflow_execution_id="test-workflow-run-id",
|
||||||
}
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
@ -10,7 +10,7 @@ class TestAppendVariablesRecursively:
|
|||||||
|
|
||||||
def test_append_simple_dict_value(self):
|
def test_append_simple_dict_value(self):
|
||||||
"""Test appending a simple dictionary value"""
|
"""Test appending a simple dictionary value"""
|
||||||
pool = VariablePool()
|
pool = VariablePool.empty()
|
||||||
node_id = "test_node"
|
node_id = "test_node"
|
||||||
variable_key_list = ["output"]
|
variable_key_list = ["output"]
|
||||||
variable_value = {"name": "John", "age": 30}
|
variable_value = {"name": "John", "age": 30}
|
||||||
@ -33,7 +33,7 @@ class TestAppendVariablesRecursively:
|
|||||||
|
|
||||||
def test_append_object_segment_value(self):
|
def test_append_object_segment_value(self):
|
||||||
"""Test appending an ObjectSegment value"""
|
"""Test appending an ObjectSegment value"""
|
||||||
pool = VariablePool()
|
pool = VariablePool.empty()
|
||||||
node_id = "test_node"
|
node_id = "test_node"
|
||||||
variable_key_list = ["result"]
|
variable_key_list = ["result"]
|
||||||
|
|
||||||
@ -60,7 +60,7 @@ class TestAppendVariablesRecursively:
|
|||||||
|
|
||||||
def test_append_nested_dict_value(self):
|
def test_append_nested_dict_value(self):
|
||||||
"""Test appending a nested dictionary value"""
|
"""Test appending a nested dictionary value"""
|
||||||
pool = VariablePool()
|
pool = VariablePool.empty()
|
||||||
node_id = "test_node"
|
node_id = "test_node"
|
||||||
variable_key_list = ["data"]
|
variable_key_list = ["data"]
|
||||||
|
|
||||||
@ -97,7 +97,7 @@ class TestAppendVariablesRecursively:
|
|||||||
|
|
||||||
def test_append_non_dict_value(self):
|
def test_append_non_dict_value(self):
|
||||||
"""Test appending a non-dictionary value (should not recurse)"""
|
"""Test appending a non-dictionary value (should not recurse)"""
|
||||||
pool = VariablePool()
|
pool = VariablePool.empty()
|
||||||
node_id = "test_node"
|
node_id = "test_node"
|
||||||
variable_key_list = ["simple"]
|
variable_key_list = ["simple"]
|
||||||
variable_value = "simple_string"
|
variable_value = "simple_string"
|
||||||
@ -114,7 +114,7 @@ class TestAppendVariablesRecursively:
|
|||||||
|
|
||||||
def test_append_segment_non_object_value(self):
|
def test_append_segment_non_object_value(self):
|
||||||
"""Test appending a Segment that is not ObjectSegment (should not recurse)"""
|
"""Test appending a Segment that is not ObjectSegment (should not recurse)"""
|
||||||
pool = VariablePool()
|
pool = VariablePool.empty()
|
||||||
node_id = "test_node"
|
node_id = "test_node"
|
||||||
variable_key_list = ["text"]
|
variable_key_list = ["text"]
|
||||||
variable_value = StringSegment(value="Hello World")
|
variable_value = StringSegment(value="Hello World")
|
||||||
@ -132,7 +132,7 @@ class TestAppendVariablesRecursively:
|
|||||||
|
|
||||||
def test_append_empty_dict_value(self):
|
def test_append_empty_dict_value(self):
|
||||||
"""Test appending an empty dictionary value"""
|
"""Test appending an empty dictionary value"""
|
||||||
pool = VariablePool()
|
pool = VariablePool.empty()
|
||||||
node_id = "test_node"
|
node_id = "test_node"
|
||||||
variable_key_list = ["empty"]
|
variable_key_list = ["empty"]
|
||||||
variable_value: dict[str, Any] = {}
|
variable_value: dict[str, Any] = {}
|
||||||
|
|||||||
@ -505,8 +505,8 @@ def test_build_segment_type_for_scalar():
|
|||||||
size=1000,
|
size=1000,
|
||||||
)
|
)
|
||||||
cases = [
|
cases = [
|
||||||
TestCase(0, SegmentType.NUMBER),
|
TestCase(0, SegmentType.INTEGER),
|
||||||
TestCase(0.0, SegmentType.NUMBER),
|
TestCase(0.0, SegmentType.FLOAT),
|
||||||
TestCase("", SegmentType.STRING),
|
TestCase("", SegmentType.STRING),
|
||||||
TestCase(file, SegmentType.FILE),
|
TestCase(file, SegmentType.FILE),
|
||||||
]
|
]
|
||||||
@ -531,14 +531,14 @@ class TestBuildSegmentWithType:
|
|||||||
result = build_segment_with_type(SegmentType.NUMBER, 42)
|
result = build_segment_with_type(SegmentType.NUMBER, 42)
|
||||||
assert isinstance(result, IntegerSegment)
|
assert isinstance(result, IntegerSegment)
|
||||||
assert result.value == 42
|
assert result.value == 42
|
||||||
assert result.value_type == SegmentType.NUMBER
|
assert result.value_type == SegmentType.INTEGER
|
||||||
|
|
||||||
def test_number_type_float(self):
|
def test_number_type_float(self):
|
||||||
"""Test building a number segment with float value."""
|
"""Test building a number segment with float value."""
|
||||||
result = build_segment_with_type(SegmentType.NUMBER, 3.14)
|
result = build_segment_with_type(SegmentType.NUMBER, 3.14)
|
||||||
assert isinstance(result, FloatSegment)
|
assert isinstance(result, FloatSegment)
|
||||||
assert result.value == 3.14
|
assert result.value == 3.14
|
||||||
assert result.value_type == SegmentType.NUMBER
|
assert result.value_type == SegmentType.FLOAT
|
||||||
|
|
||||||
def test_object_type(self):
|
def test_object_type(self):
|
||||||
"""Test building an object segment with correct type."""
|
"""Test building an object segment with correct type."""
|
||||||
@ -652,14 +652,14 @@ class TestBuildSegmentWithType:
|
|||||||
with pytest.raises(TypeMismatchError) as exc_info:
|
with pytest.raises(TypeMismatchError) as exc_info:
|
||||||
build_segment_with_type(SegmentType.STRING, None)
|
build_segment_with_type(SegmentType.STRING, None)
|
||||||
|
|
||||||
assert "Expected string, but got None" in str(exc_info.value)
|
assert "expected string, but got None" in str(exc_info.value)
|
||||||
|
|
||||||
def test_type_mismatch_empty_list_to_non_array(self):
|
def test_type_mismatch_empty_list_to_non_array(self):
|
||||||
"""Test type mismatch when expecting non-array type but getting empty list."""
|
"""Test type mismatch when expecting non-array type but getting empty list."""
|
||||||
with pytest.raises(TypeMismatchError) as exc_info:
|
with pytest.raises(TypeMismatchError) as exc_info:
|
||||||
build_segment_with_type(SegmentType.STRING, [])
|
build_segment_with_type(SegmentType.STRING, [])
|
||||||
|
|
||||||
assert "Expected string, but got empty list" in str(exc_info.value)
|
assert "expected string, but got empty list" in str(exc_info.value)
|
||||||
|
|
||||||
def test_type_mismatch_object_to_array(self):
|
def test_type_mismatch_object_to_array(self):
|
||||||
"""Test type mismatch when expecting array but getting object."""
|
"""Test type mismatch when expecting array but getting object."""
|
||||||
@ -674,19 +674,19 @@ class TestBuildSegmentWithType:
|
|||||||
# Integer should work
|
# Integer should work
|
||||||
result_int = build_segment_with_type(SegmentType.NUMBER, 42)
|
result_int = build_segment_with_type(SegmentType.NUMBER, 42)
|
||||||
assert isinstance(result_int, IntegerSegment)
|
assert isinstance(result_int, IntegerSegment)
|
||||||
assert result_int.value_type == SegmentType.NUMBER
|
assert result_int.value_type == SegmentType.INTEGER
|
||||||
|
|
||||||
# Float should work
|
# Float should work
|
||||||
result_float = build_segment_with_type(SegmentType.NUMBER, 3.14)
|
result_float = build_segment_with_type(SegmentType.NUMBER, 3.14)
|
||||||
assert isinstance(result_float, FloatSegment)
|
assert isinstance(result_float, FloatSegment)
|
||||||
assert result_float.value_type == SegmentType.NUMBER
|
assert result_float.value_type == SegmentType.FLOAT
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("segment_type", "value", "expected_class"),
|
("segment_type", "value", "expected_class"),
|
||||||
[
|
[
|
||||||
(SegmentType.STRING, "test", StringSegment),
|
(SegmentType.STRING, "test", StringSegment),
|
||||||
(SegmentType.NUMBER, 42, IntegerSegment),
|
(SegmentType.INTEGER, 42, IntegerSegment),
|
||||||
(SegmentType.NUMBER, 3.14, FloatSegment),
|
(SegmentType.FLOAT, 3.14, FloatSegment),
|
||||||
(SegmentType.OBJECT, {}, ObjectSegment),
|
(SegmentType.OBJECT, {}, ObjectSegment),
|
||||||
(SegmentType.NONE, None, NoneSegment),
|
(SegmentType.NONE, None, NoneSegment),
|
||||||
(SegmentType.ARRAY_STRING, [], ArrayStringSegment),
|
(SegmentType.ARRAY_STRING, [], ArrayStringSegment),
|
||||||
@ -857,5 +857,5 @@ class TestBuildSegmentValueErrors:
|
|||||||
# Verify they are processed as integers, not as errors
|
# Verify they are processed as integers, not as errors
|
||||||
assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1"
|
assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1"
|
||||||
assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0"
|
assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0"
|
||||||
assert true_segment.value_type == SegmentType.NUMBER
|
assert true_segment.value_type == SegmentType.INTEGER
|
||||||
assert false_segment.value_type == SegmentType.NUMBER
|
assert false_segment.value_type == SegmentType.INTEGER
|
||||||
|
|||||||
@ -98,7 +98,7 @@ const Question: FC<QuestionProps> = ({
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<div className='mb-2 flex justify-end last:mb-0'>
|
<div className='mb-2 flex justify-end last:mb-0'>
|
||||||
<div className={cn('group relative mr-4 flex max-w-full items-start pl-14 overflow-x-hidden', isEditing && 'flex-1')}>
|
<div className={cn('group relative mr-4 flex max-w-full items-start overflow-x-hidden pl-14', isEditing && 'flex-1')}>
|
||||||
<div className={cn('mr-2 gap-1', isEditing ? 'hidden' : 'flex')}>
|
<div className={cn('mr-2 gap-1', isEditing ? 'hidden' : 'flex')}>
|
||||||
<div
|
<div
|
||||||
className="absolute hidden gap-0.5 rounded-[10px] border-[0.5px] border-components-actionbar-border bg-components-actionbar-bg p-0.5 shadow-md backdrop-blur-sm group-hover:flex"
|
className="absolute hidden gap-0.5 rounded-[10px] border-[0.5px] border-components-actionbar-border bg-components-actionbar-bg p-0.5 shadow-md backdrop-blur-sm group-hover:flex"
|
||||||
@ -117,7 +117,7 @@ const Question: FC<QuestionProps> = ({
|
|||||||
</div>
|
</div>
|
||||||
<div
|
<div
|
||||||
ref={contentRef}
|
ref={contentRef}
|
||||||
className='w-full rounded-2xl bg-background-gradient-bg-fill-chat-bubble-bg-3 px-4 py-3 text-sm text-text-primary'
|
className='bg-background-gradient-bg-fill-chat-bubble-bg-3 w-full rounded-2xl px-4 py-3 text-sm text-text-primary'
|
||||||
style={theme?.chatBubbleColorStyle ? CssTransform(theme.chatBubbleColorStyle) : {}}
|
style={theme?.chatBubbleColorStyle ? CssTransform(theme.chatBubbleColorStyle) : {}}
|
||||||
>
|
>
|
||||||
{
|
{
|
||||||
|
|||||||
@ -61,7 +61,7 @@ export const ToolIcon = memo(({ providerName }: ToolIconProps) => {
|
|||||||
>
|
>
|
||||||
<div
|
<div
|
||||||
className={classNames(
|
className={classNames(
|
||||||
'size-5 border-[0.5px] border-components-panel-border-subtle bg-background-default-dodge relative flex items-center justify-center rounded-[6px]',
|
'relative flex size-5 items-center justify-center rounded-[6px] border-[0.5px] border-components-panel-border-subtle bg-background-default-dodge',
|
||||||
)}
|
)}
|
||||||
ref={containerRef}
|
ref={containerRef}
|
||||||
>
|
>
|
||||||
@ -73,7 +73,7 @@ export const ToolIcon = memo(({ providerName }: ToolIconProps) => {
|
|||||||
src={icon}
|
src={icon}
|
||||||
alt='tool icon'
|
alt='tool icon'
|
||||||
className={classNames(
|
className={classNames(
|
||||||
'w-full h-full size-3.5 object-cover',
|
'size-3.5 h-full w-full object-cover',
|
||||||
notSuccess && 'opacity-50',
|
notSuccess && 'opacity-50',
|
||||||
)}
|
)}
|
||||||
onError={() => setIconFetchError(true)}
|
onError={() => setIconFetchError(true)}
|
||||||
@ -82,7 +82,7 @@ export const ToolIcon = memo(({ providerName }: ToolIconProps) => {
|
|||||||
if (typeof icon === 'object') {
|
if (typeof icon === 'object') {
|
||||||
return <AppIcon
|
return <AppIcon
|
||||||
className={classNames(
|
className={classNames(
|
||||||
'w-full h-full size-3.5 object-cover',
|
'size-3.5 h-full w-full object-cover',
|
||||||
notSuccess && 'opacity-50',
|
notSuccess && 'opacity-50',
|
||||||
)}
|
)}
|
||||||
icon={icon?.content}
|
icon={icon?.content}
|
||||||
|
|||||||
Reference in New Issue
Block a user