Type phase 3 tool inputs

This commit is contained in:
Yanli 盐粒
2026-03-17 19:31:00 +08:00
parent 79433b0091
commit 61196180b8
5 changed files with 214 additions and 107 deletions

View File

@ -1040,9 +1040,10 @@ class ToolManager:
continue
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
variable_selector = tool_input.require_variable_selector()
variable = variable_pool.get(variable_selector)
if variable is None:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
raise ToolParameterError(f"Variable {variable_selector} does not exist")
parameter_value = variable.value
elif tool_input.type == "constant":
parameter_value = tool_input.value

View File

@ -1,13 +1,24 @@
from enum import IntEnum, StrEnum, auto
from typing import Any, Literal, Union
from __future__ import annotations
from pydantic import BaseModel
from enum import IntEnum, StrEnum, auto
from typing import Literal, TypeAlias
from pydantic import BaseModel, TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
AgentInputConstantValue: TypeAlias = (
list[ToolSelector] | str | int | float | bool | dict[str, object] | list[object] | None
)
VariableSelector: TypeAlias = list[str]
_AGENT_INPUT_VALUE_ADAPTER: TypeAdapter[AgentInputConstantValue] = TypeAdapter(AgentInputConstantValue)
_AGENT_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
class AgentNodeData(BaseNodeData):
type: NodeType = BuiltinNodeTypes.AGENT
@ -21,8 +32,20 @@ class AgentNodeData(BaseNodeData):
tool_node_version: str | None = None
class AgentInput(BaseModel):
value: Union[list[str], list[ToolSelector], Any]
type: Literal["mixed", "variable", "constant"]
value: AgentInputConstantValue | VariableSelector
@field_validator("value", mode="before")
@classmethod
def validate_value(
cls, value: object, validation_info: ValidationInfo
) -> AgentInputConstantValue | VariableSelector:
input_type = validation_info.data.get("type")
if input_type == "variable":
return _AGENT_VARIABLE_SELECTOR_ADAPTER.validate_python(value)
if input_type in {"mixed", "constant"}:
return _AGENT_INPUT_VALUE_ADAPTER.validate_python(value)
raise ValueError(f"Unknown agent input type: {input_type}")
agent_parameters: dict[str, AgentInput]

View File

@ -1,16 +1,17 @@
from __future__ import annotations
import json
from collections.abc import Sequence
from typing import Any, cast
from collections.abc import Mapping, Sequence
from typing import TypeAlias
from packaging.version import Version
from pydantic import ValidationError
from pydantic import TypeAdapter, ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.app.entities.app_invoke_entities import InvokeFrom
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.plugin.entities.request import InvokeCredentials
@ -28,6 +29,14 @@ from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGen
from .exceptions import AgentInputTypeError, AgentVariableNotFoundError
from .strategy_protocols import ResolvedAgentStrategy
JsonObject: TypeAlias = dict[str, object]
JsonObjectList: TypeAlias = list[JsonObject]
VariableSelector: TypeAlias = list[str]
_JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
_JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
_VARIABLE_SELECTOR_ADAPTER = TypeAdapter(VariableSelector)
class AgentRuntimeSupport:
def build_parameters(
@ -39,12 +48,12 @@ class AgentRuntimeSupport:
strategy: ResolvedAgentStrategy,
tenant_id: str,
app_id: str,
invoke_from: Any,
invoke_from: InvokeFrom,
for_log: bool = False,
) -> dict[str, Any]:
) -> dict[str, object]:
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
result: dict[str, Any] = {}
result: dict[str, object] = {}
for parameter_name in node_data.agent_parameters:
parameter = agent_parameters_dictionary.get(parameter_name)
if not parameter:
@ -54,9 +63,10 @@ class AgentRuntimeSupport:
agent_input = node_data.agent_parameters[parameter_name]
match agent_input.type:
case "variable":
variable = variable_pool.get(agent_input.value) # type: ignore[arg-type]
variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(agent_input.value)
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(agent_input.value))
raise AgentVariableNotFoundError(str(variable_selector))
parameter_value = variable.value
case "mixed" | "constant":
try:
@ -79,60 +89,38 @@ class AgentRuntimeSupport:
value = parameter_value
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
value = [tool for tool in value if tool.get("enabled", False)]
value = self._filter_mcp_type_tool(strategy, value)
for tool in value:
if "schemas" in tool:
tool.pop("schemas")
parameters = tool.get("parameters", {})
if all(isinstance(v, dict) for _, v in parameters.items()):
params = {}
for key, param in parameters.items():
if param.get("auto", ParamsAutoGenerated.OPEN) in (
ParamsAutoGenerated.CLOSE,
0,
):
value_param = param.get("value", {})
if value_param and value_param.get("type", "") == "variable":
variable_selector = value_param.get("value")
if not variable_selector:
raise ValueError("Variable selector is missing for a variable-type parameter.")
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(variable_selector))
params[key] = variable.value
else:
params[key] = value_param.get("value", "") if value_param is not None else None
else:
params[key] = None
parameters = params
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
tool["parameters"] = parameters
tool_payloads = _JSON_OBJECT_LIST_ADAPTER.validate_python(value)
value = self._normalize_tool_payloads(
strategy=strategy,
tools=tool_payloads,
variable_pool=variable_pool,
)
if not for_log:
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
value = _JSON_OBJECT_LIST_ADAPTER.validate_python(value)
tool_value = []
for tool in value:
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
setting_params = tool.get("settings", {})
parameters = tool.get("parameters", {})
provider_type = self._coerce_tool_provider_type(tool.get("type"))
setting_params = self._coerce_json_object(tool.get("settings")) or {}
parameters = self._coerce_json_object(tool.get("parameters")) or {}
manual_input_params = [key for key, value in parameters.items() if value is not None]
parameters = {**parameters, **setting_params}
provider_id = self._coerce_optional_string(tool.get("provider_name")) or ""
tool_name = self._coerce_optional_string(tool.get("tool_name")) or ""
plugin_unique_identifier = self._coerce_optional_string(tool.get("plugin_unique_identifier"))
credential_id = self._coerce_optional_string(tool.get("credential_id"))
entity = AgentToolEntity(
provider_id=tool.get("provider_name", ""),
provider_id=provider_id,
provider_type=provider_type,
tool_name=tool.get("tool_name", ""),
tool_name=tool_name,
tool_parameters=parameters,
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
credential_id=tool.get("credential_id", None),
plugin_unique_identifier=plugin_unique_identifier,
credential_id=credential_id,
)
extra = tool.get("extra", {})
extra = self._coerce_json_object(tool.get("extra")) or {}
runtime_variable_pool: VariablePool | None = None
if node_data.version != "1" or node_data.tool_node_version is not None:
@ -145,8 +133,9 @@ class AgentRuntimeSupport:
runtime_variable_pool,
)
if tool_runtime.entity.description:
description_override = self._coerce_optional_string(extra.get("description"))
tool_runtime.entity.description.llm = (
extra.get("description", "") or tool_runtime.entity.description.llm
description_override or tool_runtime.entity.description.llm
)
for tool_runtime_params in tool_runtime.entity.parameters:
tool_runtime_params.form = (
@ -167,13 +156,13 @@ class AgentRuntimeSupport:
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": runtime_parameters,
"credential_id": tool.get("credential_id", None),
"credential_id": credential_id,
"provider_type": provider_type.value,
}
)
value = tool_value
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
value = cast(dict[str, Any], value)
value = _JSON_OBJECT_ADAPTER.validate_python(value)
model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
history_prompt_messages = []
if node_data.memory:
@ -199,17 +188,27 @@ class AgentRuntimeSupport:
return result
def build_credentials(self, *, parameters: dict[str, Any]) -> InvokeCredentials:
def build_credentials(self, *, parameters: Mapping[str, object]) -> InvokeCredentials:
credentials = InvokeCredentials()
credentials.tool_credentials = {}
for tool in parameters.get("tools", []):
tools = parameters.get("tools")
if not isinstance(tools, list):
return credentials
for raw_tool in tools:
tool = self._coerce_json_object(raw_tool)
if tool is None:
continue
if not tool.get("credential_id"):
continue
try:
identity = ToolIdentity.model_validate(tool.get("identity", {}))
except ValidationError:
continue
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
credential_id = self._coerce_optional_string(tool.get("credential_id"))
if credential_id is None:
continue
credentials.tool_credentials[identity.provider] = credential_id
return credentials
def fetch_memory(
@ -232,14 +231,14 @@ class AgentRuntimeSupport:
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
def fetch_model(self, *, tenant_id: str, value: Mapping[str, object]) -> tuple[ModelInstance, AIModelEntity | None]:
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=tenant_id,
provider=value.get("provider", ""),
provider=str(value.get("provider", "")),
model_type=ModelType.LLM,
)
model_name = value.get("model", "")
model_name = str(value.get("model", ""))
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM,
model=model_name,
@ -249,7 +248,7 @@ class AgentRuntimeSupport:
model_instance = ModelManager().get_model_instance(
tenant_id=tenant_id,
provider=provider_name,
model_type=ModelType(value.get("model_type", "")),
model_type=ModelType(str(value.get("model_type", ""))),
model=model_name,
)
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
@ -268,9 +267,88 @@ class AgentRuntimeSupport:
@staticmethod
def _filter_mcp_type_tool(
strategy: ResolvedAgentStrategy,
tools: list[dict[str, Any]],
) -> list[dict[str, Any]]:
tools: JsonObjectList,
) -> JsonObjectList:
meta_version = strategy.meta_version
if meta_version and Version(meta_version) > Version("0.0.1"):
return tools
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
def _normalize_tool_payloads(
self,
*,
strategy: ResolvedAgentStrategy,
tools: JsonObjectList,
variable_pool: VariablePool,
) -> JsonObjectList:
enabled_tools = [dict(tool) for tool in tools if bool(tool.get("enabled", False))]
normalized_tools = self._filter_mcp_type_tool(strategy, enabled_tools)
for tool in normalized_tools:
tool.pop("schemas", None)
tool["parameters"] = self._resolve_tool_parameters(tool=tool, variable_pool=variable_pool)
tool["settings"] = self._resolve_tool_settings(tool)
return normalized_tools
def _resolve_tool_parameters(self, *, tool: Mapping[str, object], variable_pool: VariablePool) -> JsonObject:
parameter_configs = self._coerce_named_json_objects(tool.get("parameters"))
if parameter_configs is None:
raw_parameters = self._coerce_json_object(tool.get("parameters"))
return raw_parameters or {}
resolved_parameters: JsonObject = {}
for key, parameter_config in parameter_configs.items():
if parameter_config.get("auto", ParamsAutoGenerated.OPEN) in (ParamsAutoGenerated.CLOSE, 0):
value_param = self._coerce_json_object(parameter_config.get("value"))
if value_param and value_param.get("type") == "variable":
variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(value_param.get("value"))
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(variable_selector))
resolved_parameters[key] = variable.value
else:
resolved_parameters[key] = value_param.get("value", "") if value_param is not None else None
else:
resolved_parameters[key] = None
return resolved_parameters
@staticmethod
def _resolve_tool_settings(tool: Mapping[str, object]) -> JsonObject:
settings = AgentRuntimeSupport._coerce_named_json_objects(tool.get("settings"))
if settings is None:
return {}
return {key: setting.get("value") for key, setting in settings.items()}
@staticmethod
def _coerce_json_object(value: object) -> JsonObject | None:
try:
return _JSON_OBJECT_ADAPTER.validate_python(value)
except ValidationError:
return None
@staticmethod
def _coerce_optional_string(value: object) -> str | None:
return value if isinstance(value, str) else None
@staticmethod
def _coerce_tool_provider_type(value: object) -> ToolProviderType:
if isinstance(value, ToolProviderType):
return value
if isinstance(value, str):
return ToolProviderType(value)
return ToolProviderType.BUILT_IN
@classmethod
def _coerce_named_json_objects(cls, value: object) -> dict[str, JsonObject] | None:
if not isinstance(value, dict):
return None
coerced: dict[str, JsonObject] = {}
for key, item in value.items():
if not isinstance(key, str):
return None
json_object = cls._coerce_json_object(item)
if json_object is None:
return None
coerced[key] = json_object
return coerced

View File

@ -1,12 +1,24 @@
from typing import Any, Literal, Union
from __future__ import annotations
from pydantic import BaseModel, field_validator
from typing import Literal, TypeAlias
from pydantic import BaseModel, TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.tools.entities.tool_entities import ToolProviderType
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
ToolConfigurationValue: TypeAlias = str | int | float | bool
ToolConfigurations: TypeAlias = dict[str, ToolConfigurationValue]
ToolInputConstantValue: TypeAlias = str | int | float | bool | dict[str, object] | list[object] | None
VariableSelector: TypeAlias = list[str]
_TOOL_CONFIGURATIONS_ADAPTER: TypeAdapter[ToolConfigurations] = TypeAdapter(ToolConfigurations)
_TOOL_INPUT_MIXED_ADAPTER: TypeAdapter[str] = TypeAdapter(str)
_TOOL_INPUT_CONSTANT_ADAPTER: TypeAdapter[ToolInputConstantValue] = TypeAdapter(ToolInputConstantValue)
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
class ToolEntity(BaseModel):
provider_id: str
@ -14,52 +26,41 @@ class ToolEntity(BaseModel):
provider_name: str # redundancy
tool_name: str
tool_label: str # redundancy
tool_configurations: dict[str, Any]
tool_configurations: ToolConfigurations
credential_id: str | None = None
plugin_unique_identifier: str | None = None # redundancy
@field_validator("tool_configurations", mode="before")
@classmethod
def validate_tool_configurations(cls, value, values: ValidationInfo):
if not isinstance(value, dict):
raise ValueError("tool_configurations must be a dictionary")
for key in values.data.get("tool_configurations", {}):
value = values.data.get("tool_configurations", {}).get(key)
if not isinstance(value, str | int | float | bool):
raise ValueError(f"{key} must be a string")
return value
def validate_tool_configurations(cls, value: object, _validation_info: ValidationInfo) -> ToolConfigurations:
return _TOOL_CONFIGURATIONS_ADAPTER.validate_python(value)
class ToolNodeData(BaseNodeData, ToolEntity):
type: NodeType = BuiltinNodeTypes.TOOL
class ToolInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]
type: Literal["mixed", "variable", "constant"]
value: ToolInputConstantValue | VariableSelector
@field_validator("type", mode="before")
@field_validator("value", mode="before")
@classmethod
def check_type(cls, value, validation_info: ValidationInfo):
typ = value
value = validation_info.data.get("value")
def validate_value(
cls, value: object, validation_info: ValidationInfo
) -> ToolInputConstantValue | VariableSelector:
input_type = validation_info.data.get("type")
if input_type == "mixed":
return _TOOL_INPUT_MIXED_ADAPTER.validate_python(value)
if input_type == "variable":
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
if input_type == "constant":
return _TOOL_INPUT_CONSTANT_ADAPTER.validate_python(value)
raise ValueError(f"Unknown tool input type: {input_type}")
if value is None:
return typ
if typ == "mixed" and not isinstance(value, str):
raise ValueError("value must be a string")
elif typ == "variable":
if not isinstance(value, list):
raise ValueError("value must be a list")
for val in value:
if not isinstance(val, str):
raise ValueError("value must be a list of strings")
elif typ == "constant" and not isinstance(value, (allowed_types := (str, int, float, bool, dict, list))):
raise ValueError(f"value must be one of: {', '.join(t.__name__ for t in allowed_types)}")
return typ
def require_variable_selector(self) -> VariableSelector:
if self.type != "variable":
raise ValueError(f"Expected variable tool input, got {self.type}")
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
tool_parameters: dict[str, ToolInput]
# The version of the tool parameter.
@ -69,7 +70,7 @@ class ToolNodeData(BaseNodeData, ToolEntity):
@field_validator("tool_parameters", mode="before")
@classmethod
def filter_none_tool_inputs(cls, value):
def filter_none_tool_inputs(cls, value: object) -> object:
if not isinstance(value, dict):
return value
@ -80,8 +81,10 @@ class ToolNodeData(BaseNodeData, ToolEntity):
}
@staticmethod
def _has_valid_value(tool_input):
def _has_valid_value(tool_input: object) -> bool:
"""Check if the value is valid"""
if isinstance(tool_input, dict):
return tool_input.get("value") is not None
return getattr(tool_input, "value", None) is not None
if isinstance(tool_input, ToolNodeData.ToolInput):
return tool_input.value is not None
return False

View File

@ -225,10 +225,11 @@ class ToolNode(Node[ToolNodeData]):
continue
tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
variable_selector = tool_input.require_variable_selector()
variable = variable_pool.get(variable_selector)
if variable is None:
if parameter.required:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
raise ToolParameterError(f"Variable {variable_selector} does not exist")
continue
parameter_value = variable.value
elif tool_input.type in {"mixed", "constant"}:
@ -510,8 +511,9 @@ class ToolNode(Node[ToolNodeData]):
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value
variable_selector = input.require_variable_selector()
selector_key = ".".join(variable_selector)
result[f"#{selector_key}#"] = variable_selector
case "constant":
pass