mirror of
https://github.com/langgenius/dify.git
synced 2026-03-04 23:36:20 +08:00
130 lines
4.1 KiB
Python
130 lines
4.1 KiB
Python
from typing import Annotated, Any, Literal
|
|
|
|
from pydantic import (
|
|
BaseModel,
|
|
BeforeValidator,
|
|
Field,
|
|
field_validator,
|
|
)
|
|
|
|
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
|
from dify_graph.nodes.base import BaseNodeData
|
|
from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig
|
|
from dify_graph.variables.types import SegmentType
|
|
|
|
# Legacy type names that are still accepted on input; `_validate_type`
# maps them onto their `SegmentType` equivalents.
_OLD_BOOL_TYPE_NAME = "bool"
_OLD_SELECT_TYPE_NAME = "select"

# Every raw value accepted for a parameter's `type` field.
_VALID_PARAMETER_TYPES = frozenset(
    {
        # Scalar types.
        SegmentType.STRING,  # "string"
        SegmentType.NUMBER,  # "number"
        SegmentType.BOOLEAN,
        # Array types.
        SegmentType.ARRAY_STRING,
        SegmentType.ARRAY_NUMBER,
        SegmentType.ARRAY_OBJECT,
        SegmentType.ARRAY_BOOLEAN,
        # Legacy names.
        _OLD_BOOL_TYPE_NAME,  # old boolean type used by Parameter Extractor node
        _OLD_SELECT_TYPE_NAME,  # string type with enumeration choices.
    }
)
|
|
|
|
|
|
def _validate_type(parameter_type: str) -> SegmentType:
    """Normalize a raw parameter type into a ``SegmentType``.

    Used as the ``BeforeValidator`` of ``ParameterConfig.type``, so it runs on
    the raw input value before pydantic's own validation.

    :param parameter_type: raw type value; must be a member of
        ``_VALID_PARAMETER_TYPES``
    :return: the matching ``SegmentType``; the legacy names ``"bool"`` and
        ``"select"`` map to ``SegmentType.BOOLEAN`` and ``SegmentType.STRING``
    :raises ValueError: if the value is not an allowed parameter type
    """
    try:
        is_allowed = parameter_type in _VALID_PARAMETER_TYPES
    except TypeError:
        # Unhashable raw input (e.g. a list or dict) cannot be looked up in a
        # frozenset. Treat it as an invalid type so the caller gets the same
        # ValueError as for any other bad value instead of a stray TypeError.
        is_allowed = False

    if not is_allowed:
        raise ValueError(f"type {parameter_type} is not allowed to use in Parameter Extractor node.")

    # Translate the legacy names before falling back to the enum constructor.
    if parameter_type == _OLD_BOOL_TYPE_NAME:
        return SegmentType.BOOLEAN
    if parameter_type == _OLD_SELECT_TYPE_NAME:
        return SegmentType.STRING
    return SegmentType(parameter_type)
|
|
|
|
|
|
class ParameterConfig(BaseModel):
    """
    Parameter Config.

    Describes a single parameter of the Parameter Extractor node: its name,
    type, optional enumeration choices, description, and whether it is
    required.
    """

    name: str
    # Raw input is normalized by `_validate_type` before pydantic validates
    # the field as a `SegmentType`.
    type: Annotated[SegmentType, BeforeValidator(_validate_type)]
    options: list[str] | None = None
    description: str
    required: bool

    @field_validator("name", mode="before")
    @classmethod
    def validate_name(cls, value) -> str:
        """Reject empty and reserved names, then coerce the value to ``str``.

        :raises ValueError: if the name is falsy or is one of the reserved
            names ``__reason`` / ``__is_success``
        """
        if not value:
            raise ValueError("Parameter name is required")
        if value in {"__reason", "__is_success"}:
            raise ValueError("Invalid parameter name, __reason and __is_success are reserved")
        return str(value)

    def is_array_type(self) -> bool:
        """Return whether this parameter's type is an array type."""
        return self.type.is_array_type()

    def element_type(self) -> SegmentType:
        """Return the element type of the parameter.

        Raises a ValueError if the parameter's type is not an array type.
        Raises an AssertionError if the underlying type reports no element type.
        """
        element_type = self.type.element_type()
        # At this point, self.type is guaranteed to be one of `ARRAY_STRING`,
        # `ARRAY_NUMBER`, `ARRAY_OBJECT`, or `ARRAY_BOOLEAN`.
        #
        # See: _VALID_PARAMETER_TYPES for reference.
        #
        # Raise explicitly instead of using `assert`: assert statements are
        # removed under `python -O`, which would let `None` escape despite the
        # declared return type.
        if element_type is None:
            raise AssertionError(f"the element type should not be None, {self.type=}")
        return element_type
|
|
|
|
|
|
class ParameterExtractorNodeData(BaseNodeData):
    """
    Parameter Extractor Node Data.
    """

    model: ModelConfig
    query: list[str]
    parameters: list[ParameterConfig]
    instruction: str | None = None
    memory: MemoryConfig | None = None
    reasoning_mode: Literal["function_call", "prompt"]
    vision: VisionConfig = Field(default_factory=VisionConfig)

    @field_validator("reasoning_mode", mode="before")
    @classmethod
    def set_reasoning_mode(cls, v) -> str:
        """Replace a falsy ``reasoning_mode`` input with ``"function_call"``."""
        return v or "function_call"

    def get_parameter_json_schema(self) -> dict[str, Any]:
        """
        Get parameter json schema.

        Builds an object schema with one property per configured parameter.
        Parameters flagged as required are listed under ``required``, and
        non-empty ``options`` are emitted as ``enum``.

        :return: parameter json schema
        :raises AssertionError: if an array-typed parameter reports no element type
        """
        parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}

        for parameter in self.parameters:
            parameter_schema: dict[str, Any] = {"description": parameter.description}

            if parameter.type == SegmentType.STRING:
                parameter_schema["type"] = "string"
            elif parameter.type.is_array_type():
                parameter_schema["type"] = "array"
                element_type = parameter.type.element_type()
                if element_type is None:
                    raise AssertionError("element type should not be None.")
                parameter_schema["items"] = {"type": element_type.value}
            else:
                # Store the plain string value, as the array branch does for
                # `items`, so the schema holds no enum members.
                parameter_schema["type"] = parameter.type.value

            if parameter.options:
                parameter_schema["enum"] = parameter.options

            parameters["properties"][parameter.name] = parameter_schema

            if parameter.required:
                parameters["required"].append(parameter.name)

        return parameters
|