Files
dify/api/factories/variable_factory.py
Novice 94b01f6821 Merge commit '92bde350' into sandboxed-agent-rebase
Made-with: Cursor

# Conflicts:
#	api/controllers/console/app/workflow_draft_variable.py
#	api/core/agent/cot_agent_runner.py
#	api/core/agent/cot_chat_agent_runner.py
#	api/core/agent/cot_completion_agent_runner.py
#	api/core/agent/fc_agent_runner.py
#	api/core/app/apps/advanced_chat/app_generator.py
#	api/core/app/apps/advanced_chat/app_runner.py
#	api/core/app/apps/agent_chat/app_runner.py
#	api/core/app/apps/workflow/app_generator.py
#	api/core/app/apps/workflow/app_runner.py
#	api/core/app/entities/app_invoke_entities.py
#	api/core/app/entities/queue_entities.py
#	api/core/llm_generator/output_parser/structured_output.py
#	api/core/workflow/workflow_entry.py
#	api/dify_graph/context/__init__.py
#	api/dify_graph/entities/tool_entities.py
#	api/dify_graph/file/file_manager.py
#	api/dify_graph/graph_engine/response_coordinator/coordinator.py
#	api/dify_graph/graph_events/node.py
#	api/dify_graph/node_events/node.py
#	api/dify_graph/nodes/agent/agent_node.py
#	api/dify_graph/nodes/llm/entities.py
#	api/dify_graph/nodes/llm/llm_utils.py
#	api/dify_graph/nodes/llm/node.py
#	api/dify_graph/nodes/question_classifier/question_classifier_node.py
#	api/dify_graph/runtime/graph_runtime_state.py
#	api/dify_graph/variables/segments.py
#	api/factories/variable_factory.py
#	api/services/variable_truncator.py
#	api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py
#	api/uv.lock
#	web/app/components/app-sidebar/app-info.tsx
#	web/app/components/app-sidebar/app-sidebar-dropdown.tsx
#	web/app/components/app/create-app-modal/index.spec.tsx
#	web/app/components/apps/__tests__/list.spec.tsx
#	web/app/components/apps/app-card.tsx
#	web/app/components/apps/list.tsx
#	web/app/components/header/account-dropdown/compliance.tsx
#	web/app/components/header/account-dropdown/index.tsx
#	web/app/components/header/account-dropdown/support.tsx
#	web/app/components/workflow-app/components/workflow-onboarding-modal/index.tsx
#	web/app/components/workflow/panel/debug-and-preview/hooks.ts
#	web/contract/console/apps.ts
#	web/contract/router.ts
#	web/eslint-suppressions.json
#	web/next.config.ts
#	web/pnpm-lock.yaml
2026-03-23 09:39:49 +08:00

357 lines
13 KiB
Python

from collections.abc import Mapping, Sequence
from typing import Any, cast
from uuid import uuid4
from configs import dify_config
from dify_graph.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
)
from dify_graph.file import File
from dify_graph.model_runtime.entities import PromptMessage
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageRole,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from dify_graph.variables.exc import VariableError
from dify_graph.variables.segments import (
ArrayAnySegment,
ArrayBooleanSegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayPromptMessageSegment,
ArraySegment,
ArrayStringSegment,
BooleanSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
StringSegment,
)
from dify_graph.variables.types import SegmentType
from dify_graph.variables.variables import (
ArrayAnyVariable,
ArrayBooleanVariable,
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayPromptMessageVariable,
ArrayStringVariable,
BooleanVariable,
FileVariable,
FloatVariable,
IntegerVariable,
NoneVariable,
ObjectVariable,
SecretVariable,
StringVariable,
VariableBase,
)
class UnsupportedSegmentTypeError(Exception):
    """Raised when a segment's class has no matching variable class to convert to."""
class TypeMismatchError(Exception):
    """Raised when a value is not compatible with the segment type it was expected to have."""
# Lookup table from a concrete segment class to the variable class that
# `segment_to_variable` instantiates for it. The lookup there is keyed on the
# exact `type(segment)`, so a segment class that is not listed here is rejected
# with `UnsupportedSegmentTypeError`.
SEGMENT_TO_VARIABLE_MAP = {
    ArrayAnySegment: ArrayAnyVariable,
    ArrayBooleanSegment: ArrayBooleanVariable,
    ArrayFileSegment: ArrayFileVariable,
    ArrayNumberSegment: ArrayNumberVariable,
    ArrayObjectSegment: ArrayObjectVariable,
    ArrayPromptMessageSegment: ArrayPromptMessageVariable,
    ArrayStringSegment: ArrayStringVariable,
    BooleanSegment: BooleanVariable,
    FileSegment: FileVariable,
    FloatSegment: FloatVariable,
    IntegerSegment: IntegerVariable,
    NoneSegment: NoneVariable,
    ObjectSegment: ObjectVariable,
    StringSegment: StringVariable,
}
def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
    """
    Build a conversation variable from its mapping form.

    The variable is built by `_build_variable_from_mapping`, using
    `[CONVERSATION_VARIABLE_NODE_ID, <name>]` as the selector.

    Raises:
        VariableError: If the mapping has no (or an empty) ``name``.
    """
    variable_name = mapping.get("name")
    if not variable_name:
        raise VariableError("missing name")
    conversation_selector = [CONVERSATION_VARIABLE_NODE_ID, variable_name]
    return _build_variable_from_mapping(mapping=mapping, selector=conversation_selector)
def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
    """
    Build an environment variable from its mapping form.

    The variable is built by `_build_variable_from_mapping`, using
    `[ENVIRONMENT_VARIABLE_NODE_ID, <name>]` as the selector.

    Raises:
        VariableError: If the mapping has no (or an empty) ``name``.
    """
    variable_name = mapping.get("name")
    if not variable_name:
        raise VariableError("missing name")
    environment_selector = [ENVIRONMENT_VARIABLE_NODE_ID, variable_name]
    return _build_variable_from_mapping(mapping=mapping, selector=environment_selector)
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
    """
    Return the ``variable`` entry of a pipeline variable mapping.

    Raises:
        VariableError: If the mapping has no (or a falsy) ``variable`` entry.
    """
    pipeline_variable = mapping.get("variable")
    if not pipeline_variable:
        raise VariableError("missing variable")
    # NOTE(review): the entry is handed back as-is, without validation or conversion;
    # confirm with callers that it really is a VariableBase as the annotation states.
    return pipeline_variable
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> VariableBase:
    """
    Build an environment variable or a conversation variable from its mapping form.

    ``mapping["value_type"]`` selects the variable model, and the mapping is passed to that
    model's ``model_validate``. The File type is not supported by this factory.

    Args:
        mapping: Serialized variable; must contain non-None ``value_type`` and ``value``.
        selector: Selector assigned to the result when the validated variable has none.

    Returns:
        The validated variable.

    Raises:
        VariableError: If ``value_type`` or ``value`` is missing, the value does not fit the
            declared type, the type is not supported, or the variable's size exceeds
            ``dify_config.MAX_VARIABLE_SIZE``.
    """
    value_type = mapping.get("value_type")
    if value_type is None:
        raise VariableError("missing value type")
    value = mapping.get("value")
    if value is None:
        raise VariableError("missing value")

    # Array types handled here; each requires the raw value to be a list.
    list_backed_models = (
        (SegmentType.ARRAY_STRING, ArrayStringVariable),
        (SegmentType.ARRAY_NUMBER, ArrayNumberVariable),
        (SegmentType.ARRAY_OBJECT, ArrayObjectVariable),
        (SegmentType.ARRAY_BOOLEAN, ArrayBooleanVariable),
    )
    is_generic_number = value_type == SegmentType.NUMBER

    variable: VariableBase
    if value_type == SegmentType.STRING:
        variable = StringVariable.model_validate(mapping)
    elif value_type == SegmentType.SECRET:
        variable = SecretVariable.model_validate(mapping)
    elif (is_generic_number or value_type == SegmentType.INTEGER) and isinstance(value, int):
        # Validate against a copy whose value_type is pinned to INTEGER.
        variable = IntegerVariable.model_validate({**mapping, "value_type": SegmentType.INTEGER})
    elif (is_generic_number or value_type == SegmentType.FLOAT) and isinstance(value, float):
        # Validate against a copy whose value_type is pinned to FLOAT.
        variable = FloatVariable.model_validate({**mapping, "value_type": SegmentType.FLOAT})
    elif value_type == SegmentType.BOOLEAN:
        variable = BooleanVariable.model_validate(mapping)
    elif is_generic_number and not isinstance(value, (float, int)):
        raise VariableError(f"invalid number value {value}")
    elif value_type == SegmentType.OBJECT and isinstance(value, dict):
        variable = ObjectVariable.model_validate(mapping)
    else:
        array_model = next(
            (model for array_type, model in list_backed_models if value_type == array_type),
            None,
        )
        if array_model is None or not isinstance(value, list):
            raise VariableError(f"not supported value type {value_type}")
        variable = array_model.model_validate(mapping)

    size_limit = dify_config.MAX_VARIABLE_SIZE
    if variable.size > size_limit:
        raise VariableError(f"variable size {variable.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")

    # Only fall back to the caller's selector when the mapping did not supply one.
    if not variable.selector:
        variable = variable.model_copy(update={"selector": selector})
    return cast(VariableBase, variable)
def build_segment(value: Any, /) -> Segment:
    """
    Build a segment by inspecting the runtime type of ``value``.

    NOTE: If you have runtime type information available, consider using
    `build_segment_with_type` instead.

    Returns:
        ``value`` itself when it already is a Segment, otherwise a new segment wrapping it.

    Raises:
        ValueError: If ``value`` (or the element type of a list) is not supported.
    """
    if value is None:
        return NoneSegment()
    if isinstance(value, Segment):
        return value

    # Order matters: bool is checked before int because bool is a subclass of int.
    scalar_segments = (
        (str, StringSegment),
        (bool, BooleanSegment),
        (int, IntegerSegment),
        (float, FloatSegment),
        (dict, ObjectSegment),
        (File, FileSegment),
    )
    for python_type, segment_class in scalar_segments:
        if isinstance(value, python_type):
            return segment_class(value=value)

    if isinstance(value, PromptMessage):
        # Single PromptMessage should be wrapped in a list
        return ArrayPromptMessageSegment(value=[value])
    if isinstance(value, list):
        return _build_array_segment(value)
    raise ValueError(f"not supported value {value}")


def _build_array_segment(values: list) -> Segment:
    """Pick the array segment class for ``values`` based on the segments its elements build to."""
    # A non-empty list made only of PromptMessage objects gets its own array type.
    if values and all(isinstance(element, PromptMessage) for element in values):
        return ArrayPromptMessageSegment(value=values)

    element_segments = [build_segment(element) for element in values]
    element_types = {segment.value_type for segment in element_segments}

    # Covers the empty list as well as lists whose elements are all arrays themselves.
    if all(isinstance(segment, ArraySegment) for segment in element_segments):
        return ArrayAnySegment(value=values)

    numeric_types = {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}
    if len(element_types) != 1:
        # Mixed element types: still a number array when every type is numeric.
        if element_types.issubset(numeric_types):
            return ArrayNumberSegment(value=values)
        return ArrayAnySegment(value=values)

    element_type = element_types.pop()
    if element_type == SegmentType.STRING:
        return ArrayStringSegment(value=values)
    if element_type in numeric_types:
        return ArrayNumberSegment(value=values)
    if element_type == SegmentType.BOOLEAN:
        return ArrayBooleanSegment(value=values)
    if element_type == SegmentType.OBJECT:
        return ArrayObjectSegment(value=values)
    if element_type == SegmentType.FILE:
        return ArrayFileSegment(value=values)
    if element_type == SegmentType.NONE:
        return ArrayAnySegment(value=values)
    # This should be unreachable.
    raise ValueError(f"not supported value {values}")
# Segment class to instantiate for each SegmentType; used by `build_segment_with_type`.
# There is no entry for SegmentType.NUMBER: `build_segment_with_type` resolves NUMBER
# through the inferred INTEGER / FLOAT type before looking up this table.
_segment_factory: Mapping[SegmentType, type[Segment]] = {
    SegmentType.NONE: NoneSegment,
    SegmentType.STRING: StringSegment,
    SegmentType.INTEGER: IntegerSegment,
    SegmentType.FLOAT: FloatSegment,
    SegmentType.FILE: FileSegment,
    SegmentType.BOOLEAN: BooleanSegment,
    SegmentType.OBJECT: ObjectSegment,
    # Array types
    SegmentType.ARRAY_ANY: ArrayAnySegment,
    SegmentType.ARRAY_STRING: ArrayStringSegment,
    SegmentType.ARRAY_NUMBER: ArrayNumberSegment,
    SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
    SegmentType.ARRAY_FILE: ArrayFileSegment,
    SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment,
    SegmentType.ARRAY_PROMPT_MESSAGE: ArrayPromptMessageSegment,
}
def _deserialize_prompt_message_list(value: list[dict]) -> list[PromptMessage]:
    """
    Turn a list of dicts back into a list of PromptMessage objects.

    This is used when loading ARRAY_PROMPT_MESSAGE from database,
    where PromptMessage objects are serialized as dicts. The ``role`` entry of each dict
    (either the enum member or its plain string form) selects the message class.
    """
    role_to_message_class = (
        ((PromptMessageRole.USER, "user"), UserPromptMessage),
        ((PromptMessageRole.ASSISTANT, "assistant"), AssistantPromptMessage),
        ((PromptMessageRole.SYSTEM, "system"), SystemPromptMessage),
        ((PromptMessageRole.TOOL, "tool"), ToolPromptMessage),
    )

    def _message_class_for(role: Any) -> type[PromptMessage]:
        for accepted_roles, message_class in role_to_message_class:
            if role in accepted_roles:
                return message_class
        # Fallback to UserPromptMessage for unknown roles
        return UserPromptMessage

    return [_message_class_for(raw.get("role")).model_validate(raw) for raw in value]
def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
    """
    Build a segment with explicit type checking.

    This function creates a segment from a value while enforcing type compatibility
    with the specified segment_type. It provides stricter type validation compared
    to the standard build_segment function.

    Args:
        segment_type: The expected SegmentType for the resulting segment
        value: The value to be converted into a segment

    Returns:
        Segment: A segment instance of the appropriate type

    Raises:
        TypeMismatchError: If the value type doesn't match the expected segment_type

    Special Cases:
        - For empty list [] values, if segment_type is array[*] (including
          ARRAY_PROMPT_MESSAGE), returns the corresponding array type
        - NUMBER accepts values inferred as INTEGER or FLOAT and returns the segment
          of the inferred type
        - ARRAY_PROMPT_MESSAGE accepts a value inferred as ARRAY_OBJECT; the dicts are
          deserialized back into PromptMessage objects
        - Type validation is performed before segment creation

    Examples:
        >>> build_segment_with_type(SegmentType.STRING, "hello")
        StringSegment(value="hello")
        >>> build_segment_with_type(SegmentType.ARRAY_STRING, [])
        ArrayStringSegment(value=[])
        >>> build_segment_with_type(SegmentType.STRING, 123)
        # Raises TypeMismatchError
    """
    # Handle None values
    if value is None:
        if segment_type == SegmentType.NONE:
            return NoneSegment()
        raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None")

    # Handle empty list special case for array types: an empty list has no elements to
    # infer a type from, so the requested array type decides the segment class.
    if isinstance(value, list) and len(value) == 0:
        empty_array_segments = (
            (SegmentType.ARRAY_ANY, ArrayAnySegment),
            (SegmentType.ARRAY_STRING, ArrayStringSegment),
            (SegmentType.ARRAY_BOOLEAN, ArrayBooleanSegment),
            (SegmentType.ARRAY_NUMBER, ArrayNumberSegment),
            (SegmentType.ARRAY_OBJECT, ArrayObjectSegment),
            (SegmentType.ARRAY_FILE, ArrayFileSegment),
            # An empty prompt-message array is a valid value too; without this entry it
            # was rejected even though every other array type accepted [].
            (SegmentType.ARRAY_PROMPT_MESSAGE, ArrayPromptMessageSegment),
        )
        for array_type, array_segment_class in empty_array_segments:
            if segment_type == array_type:
                return array_segment_class(value=value)
        raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list")

    inferred_type = SegmentType.infer_segment_type(value)

    # Type compatibility checking
    if inferred_type is None:
        raise TypeMismatchError(
            f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}"
        )
    if inferred_type == segment_type:
        segment_class = _segment_factory[segment_type]
        return segment_class(value_type=segment_type, value=value)
    elif segment_type == SegmentType.NUMBER and inferred_type in (
        SegmentType.INTEGER,
        SegmentType.FLOAT,
    ):
        # NUMBER has no segment class of its own; use the concrete inferred numeric type.
        segment_class = _segment_factory[inferred_type]
        return segment_class(value_type=inferred_type, value=value)
    elif segment_type == SegmentType.ARRAY_PROMPT_MESSAGE and inferred_type == SegmentType.ARRAY_OBJECT:
        # PromptMessage serializes to dict, so ARRAY_OBJECT is compatible with ARRAY_PROMPT_MESSAGE
        # Need to deserialize dict list back to PromptMessage objects
        deserialized_messages = _deserialize_prompt_message_list(value)
        segment_class = _segment_factory[segment_type]
        return segment_class(value_type=segment_type, value=deserialized_messages)
    else:
        raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}")
def segment_to_variable(
    *,
    segment: Segment,
    selector: Sequence[str],
    id: str | None = None,
    name: str | None = None,
    description: str = "",
) -> VariableBase:
    """
    Wrap a segment into the variable class registered for its type.

    Args:
        segment: Segment to convert; returned unchanged when it already is a variable.
        selector: Selector stored on the variable (copied into a list).
        id: Variable id; a random UUID4 string is generated when omitted or empty.
        name: Variable name; defaults to the last element of ``selector``.
        description: Description stored on the variable.

    Raises:
        UnsupportedSegmentTypeError: If the segment's exact class is not a key of
            ``SEGMENT_TO_VARIABLE_MAP``.
    """
    if isinstance(segment, VariableBase):
        return segment

    variable_name = name or selector[-1]
    variable_id = id or str(uuid4())

    segment_class = type(segment)
    variable_class = SEGMENT_TO_VARIABLE_MAP.get(segment_class)
    if variable_class is None:
        raise UnsupportedSegmentTypeError(f"not supported segment type {segment_class}")

    variable = variable_class(
        id=variable_id,
        name=variable_name,
        description=description,
        value=segment.value,
        selector=list(selector),
    )
    return cast(VariableBase, variable)