mirror of
https://github.com/langgenius/dify.git
synced 2026-04-23 20:36:14 +08:00
Type phase 3 tool inputs
This commit is contained in:
@ -1040,9 +1040,10 @@ class ToolManager:
|
||||
continue
|
||||
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
|
||||
if tool_input.type == "variable":
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
variable_selector = tool_input.require_variable_selector()
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
raise ToolParameterError(f"Variable {variable_selector} does not exist")
|
||||
parameter_value = variable.value
|
||||
elif tool_input.type == "constant":
|
||||
parameter_value = tool_input.value
|
||||
|
||||
@ -1,13 +1,24 @@
|
||||
from enum import IntEnum, StrEnum, auto
|
||||
from typing import Any, Literal, Union
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
from enum import IntEnum, StrEnum, auto
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.tools.entities.tool_entities import ToolSelector
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
|
||||
AgentInputConstantValue: TypeAlias = (
|
||||
list[ToolSelector] | str | int | float | bool | dict[str, object] | list[object] | None
|
||||
)
|
||||
VariableSelector: TypeAlias = list[str]
|
||||
|
||||
_AGENT_INPUT_VALUE_ADAPTER: TypeAdapter[AgentInputConstantValue] = TypeAdapter(AgentInputConstantValue)
|
||||
_AGENT_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
|
||||
|
||||
|
||||
class AgentNodeData(BaseNodeData):
|
||||
type: NodeType = BuiltinNodeTypes.AGENT
|
||||
@ -21,8 +32,20 @@ class AgentNodeData(BaseNodeData):
|
||||
tool_node_version: str | None = None
|
||||
|
||||
class AgentInput(BaseModel):
|
||||
value: Union[list[str], list[ToolSelector], Any]
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
value: AgentInputConstantValue | VariableSelector
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def validate_value(
|
||||
cls, value: object, validation_info: ValidationInfo
|
||||
) -> AgentInputConstantValue | VariableSelector:
|
||||
input_type = validation_info.data.get("type")
|
||||
if input_type == "variable":
|
||||
return _AGENT_VARIABLE_SELECTOR_ADAPTER.validate_python(value)
|
||||
if input_type in {"mixed", "constant"}:
|
||||
return _AGENT_INPUT_VALUE_ADAPTER.validate_python(value)
|
||||
raise ValueError(f"Unknown agent input type: {input_type}")
|
||||
|
||||
agent_parameters: dict[str, AgentInput]
|
||||
|
||||
|
||||
@ -1,16 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TypeAlias
|
||||
|
||||
from packaging.version import Version
|
||||
from pydantic import ValidationError
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
@ -28,6 +29,14 @@ from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGen
|
||||
from .exceptions import AgentInputTypeError, AgentVariableNotFoundError
|
||||
from .strategy_protocols import ResolvedAgentStrategy
|
||||
|
||||
JsonObject: TypeAlias = dict[str, object]
|
||||
JsonObjectList: TypeAlias = list[JsonObject]
|
||||
VariableSelector: TypeAlias = list[str]
|
||||
|
||||
_JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
|
||||
_JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
|
||||
_VARIABLE_SELECTOR_ADAPTER = TypeAdapter(VariableSelector)
|
||||
|
||||
|
||||
class AgentRuntimeSupport:
|
||||
def build_parameters(
|
||||
@ -39,12 +48,12 @@ class AgentRuntimeSupport:
|
||||
strategy: ResolvedAgentStrategy,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
invoke_from: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
for_log: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
) -> dict[str, object]:
|
||||
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
result: dict[str, object] = {}
|
||||
for parameter_name in node_data.agent_parameters:
|
||||
parameter = agent_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
@ -54,9 +63,10 @@ class AgentRuntimeSupport:
|
||||
agent_input = node_data.agent_parameters[parameter_name]
|
||||
match agent_input.type:
|
||||
case "variable":
|
||||
variable = variable_pool.get(agent_input.value) # type: ignore[arg-type]
|
||||
variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(agent_input.value)
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(agent_input.value))
|
||||
raise AgentVariableNotFoundError(str(variable_selector))
|
||||
parameter_value = variable.value
|
||||
case "mixed" | "constant":
|
||||
try:
|
||||
@ -79,60 +89,38 @@ class AgentRuntimeSupport:
|
||||
|
||||
value = parameter_value
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
value = [tool for tool in value if tool.get("enabled", False)]
|
||||
value = self._filter_mcp_type_tool(strategy, value)
|
||||
for tool in value:
|
||||
if "schemas" in tool:
|
||||
tool.pop("schemas")
|
||||
parameters = tool.get("parameters", {})
|
||||
if all(isinstance(v, dict) for _, v in parameters.items()):
|
||||
params = {}
|
||||
for key, param in parameters.items():
|
||||
if param.get("auto", ParamsAutoGenerated.OPEN) in (
|
||||
ParamsAutoGenerated.CLOSE,
|
||||
0,
|
||||
):
|
||||
value_param = param.get("value", {})
|
||||
if value_param and value_param.get("type", "") == "variable":
|
||||
variable_selector = value_param.get("value")
|
||||
if not variable_selector:
|
||||
raise ValueError("Variable selector is missing for a variable-type parameter.")
|
||||
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(variable_selector))
|
||||
|
||||
params[key] = variable.value
|
||||
else:
|
||||
params[key] = value_param.get("value", "") if value_param is not None else None
|
||||
else:
|
||||
params[key] = None
|
||||
parameters = params
|
||||
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
|
||||
tool["parameters"] = parameters
|
||||
tool_payloads = _JSON_OBJECT_LIST_ADAPTER.validate_python(value)
|
||||
value = self._normalize_tool_payloads(
|
||||
strategy=strategy,
|
||||
tools=tool_payloads,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
if not for_log:
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
value = _JSON_OBJECT_LIST_ADAPTER.validate_python(value)
|
||||
tool_value = []
|
||||
for tool in value:
|
||||
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
|
||||
setting_params = tool.get("settings", {})
|
||||
parameters = tool.get("parameters", {})
|
||||
provider_type = self._coerce_tool_provider_type(tool.get("type"))
|
||||
setting_params = self._coerce_json_object(tool.get("settings")) or {}
|
||||
parameters = self._coerce_json_object(tool.get("parameters")) or {}
|
||||
manual_input_params = [key for key, value in parameters.items() if value is not None]
|
||||
|
||||
parameters = {**parameters, **setting_params}
|
||||
provider_id = self._coerce_optional_string(tool.get("provider_name")) or ""
|
||||
tool_name = self._coerce_optional_string(tool.get("tool_name")) or ""
|
||||
plugin_unique_identifier = self._coerce_optional_string(tool.get("plugin_unique_identifier"))
|
||||
credential_id = self._coerce_optional_string(tool.get("credential_id"))
|
||||
entity = AgentToolEntity(
|
||||
provider_id=tool.get("provider_name", ""),
|
||||
provider_id=provider_id,
|
||||
provider_type=provider_type,
|
||||
tool_name=tool.get("tool_name", ""),
|
||||
tool_name=tool_name,
|
||||
tool_parameters=parameters,
|
||||
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
|
||||
credential_id=tool.get("credential_id", None),
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
extra = tool.get("extra", {})
|
||||
extra = self._coerce_json_object(tool.get("extra")) or {}
|
||||
|
||||
runtime_variable_pool: VariablePool | None = None
|
||||
if node_data.version != "1" or node_data.tool_node_version is not None:
|
||||
@ -145,8 +133,9 @@ class AgentRuntimeSupport:
|
||||
runtime_variable_pool,
|
||||
)
|
||||
if tool_runtime.entity.description:
|
||||
description_override = self._coerce_optional_string(extra.get("description"))
|
||||
tool_runtime.entity.description.llm = (
|
||||
extra.get("description", "") or tool_runtime.entity.description.llm
|
||||
description_override or tool_runtime.entity.description.llm
|
||||
)
|
||||
for tool_runtime_params in tool_runtime.entity.parameters:
|
||||
tool_runtime_params.form = (
|
||||
@ -167,13 +156,13 @@ class AgentRuntimeSupport:
|
||||
{
|
||||
**tool_runtime.entity.model_dump(mode="json"),
|
||||
"runtime_parameters": runtime_parameters,
|
||||
"credential_id": tool.get("credential_id", None),
|
||||
"credential_id": credential_id,
|
||||
"provider_type": provider_type.value,
|
||||
}
|
||||
)
|
||||
value = tool_value
|
||||
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
|
||||
value = cast(dict[str, Any], value)
|
||||
value = _JSON_OBJECT_ADAPTER.validate_python(value)
|
||||
model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
|
||||
history_prompt_messages = []
|
||||
if node_data.memory:
|
||||
@ -199,17 +188,27 @@ class AgentRuntimeSupport:
|
||||
|
||||
return result
|
||||
|
||||
def build_credentials(self, *, parameters: dict[str, Any]) -> InvokeCredentials:
|
||||
def build_credentials(self, *, parameters: Mapping[str, object]) -> InvokeCredentials:
|
||||
credentials = InvokeCredentials()
|
||||
credentials.tool_credentials = {}
|
||||
for tool in parameters.get("tools", []):
|
||||
tools = parameters.get("tools")
|
||||
if not isinstance(tools, list):
|
||||
return credentials
|
||||
|
||||
for raw_tool in tools:
|
||||
tool = self._coerce_json_object(raw_tool)
|
||||
if tool is None:
|
||||
continue
|
||||
if not tool.get("credential_id"):
|
||||
continue
|
||||
try:
|
||||
identity = ToolIdentity.model_validate(tool.get("identity", {}))
|
||||
except ValidationError:
|
||||
continue
|
||||
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
|
||||
credential_id = self._coerce_optional_string(tool.get("credential_id"))
|
||||
if credential_id is None:
|
||||
continue
|
||||
credentials.tool_credentials[identity.provider] = credential_id
|
||||
return credentials
|
||||
|
||||
def fetch_memory(
|
||||
@ -232,14 +231,14 @@ class AgentRuntimeSupport:
|
||||
|
||||
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||
def fetch_model(self, *, tenant_id: str, value: Mapping[str, object]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id=tenant_id,
|
||||
provider=value.get("provider", ""),
|
||||
provider=str(value.get("provider", "")),
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
model_name = value.get("model", "")
|
||||
model_name = str(value.get("model", ""))
|
||||
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
model=model_name,
|
||||
@ -249,7 +248,7 @@ class AgentRuntimeSupport:
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_name,
|
||||
model_type=ModelType(value.get("model_type", "")),
|
||||
model_type=ModelType(str(value.get("model_type", ""))),
|
||||
model=model_name,
|
||||
)
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
@ -268,9 +267,88 @@ class AgentRuntimeSupport:
|
||||
@staticmethod
|
||||
def _filter_mcp_type_tool(
|
||||
strategy: ResolvedAgentStrategy,
|
||||
tools: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
tools: JsonObjectList,
|
||||
) -> JsonObjectList:
|
||||
meta_version = strategy.meta_version
|
||||
if meta_version and Version(meta_version) > Version("0.0.1"):
|
||||
return tools
|
||||
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
|
||||
|
||||
def _normalize_tool_payloads(
|
||||
self,
|
||||
*,
|
||||
strategy: ResolvedAgentStrategy,
|
||||
tools: JsonObjectList,
|
||||
variable_pool: VariablePool,
|
||||
) -> JsonObjectList:
|
||||
enabled_tools = [dict(tool) for tool in tools if bool(tool.get("enabled", False))]
|
||||
normalized_tools = self._filter_mcp_type_tool(strategy, enabled_tools)
|
||||
for tool in normalized_tools:
|
||||
tool.pop("schemas", None)
|
||||
tool["parameters"] = self._resolve_tool_parameters(tool=tool, variable_pool=variable_pool)
|
||||
tool["settings"] = self._resolve_tool_settings(tool)
|
||||
return normalized_tools
|
||||
|
||||
def _resolve_tool_parameters(self, *, tool: Mapping[str, object], variable_pool: VariablePool) -> JsonObject:
|
||||
parameter_configs = self._coerce_named_json_objects(tool.get("parameters"))
|
||||
if parameter_configs is None:
|
||||
raw_parameters = self._coerce_json_object(tool.get("parameters"))
|
||||
return raw_parameters or {}
|
||||
|
||||
resolved_parameters: JsonObject = {}
|
||||
for key, parameter_config in parameter_configs.items():
|
||||
if parameter_config.get("auto", ParamsAutoGenerated.OPEN) in (ParamsAutoGenerated.CLOSE, 0):
|
||||
value_param = self._coerce_json_object(parameter_config.get("value"))
|
||||
if value_param and value_param.get("type") == "variable":
|
||||
variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(value_param.get("value"))
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(variable_selector))
|
||||
resolved_parameters[key] = variable.value
|
||||
else:
|
||||
resolved_parameters[key] = value_param.get("value", "") if value_param is not None else None
|
||||
else:
|
||||
resolved_parameters[key] = None
|
||||
|
||||
return resolved_parameters
|
||||
|
||||
@staticmethod
|
||||
def _resolve_tool_settings(tool: Mapping[str, object]) -> JsonObject:
|
||||
settings = AgentRuntimeSupport._coerce_named_json_objects(tool.get("settings"))
|
||||
if settings is None:
|
||||
return {}
|
||||
return {key: setting.get("value") for key, setting in settings.items()}
|
||||
|
||||
@staticmethod
|
||||
def _coerce_json_object(value: object) -> JsonObject | None:
|
||||
try:
|
||||
return _JSON_OBJECT_ADAPTER.validate_python(value)
|
||||
except ValidationError:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _coerce_optional_string(value: object) -> str | None:
|
||||
return value if isinstance(value, str) else None
|
||||
|
||||
@staticmethod
|
||||
def _coerce_tool_provider_type(value: object) -> ToolProviderType:
|
||||
if isinstance(value, ToolProviderType):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return ToolProviderType(value)
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
@classmethod
|
||||
def _coerce_named_json_objects(cls, value: object) -> dict[str, JsonObject] | None:
|
||||
if not isinstance(value, dict):
|
||||
return None
|
||||
|
||||
coerced: dict[str, JsonObject] = {}
|
||||
for key, item in value.items():
|
||||
if not isinstance(key, str):
|
||||
return None
|
||||
json_object = cls._coerce_json_object(item)
|
||||
if json_object is None:
|
||||
return None
|
||||
coerced[key] = json_object
|
||||
return coerced
|
||||
|
||||
@ -1,12 +1,24 @@
|
||||
from typing import Any, Literal, Union
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
|
||||
ToolConfigurationValue: TypeAlias = str | int | float | bool
|
||||
ToolConfigurations: TypeAlias = dict[str, ToolConfigurationValue]
|
||||
ToolInputConstantValue: TypeAlias = str | int | float | bool | dict[str, object] | list[object] | None
|
||||
VariableSelector: TypeAlias = list[str]
|
||||
|
||||
_TOOL_CONFIGURATIONS_ADAPTER: TypeAdapter[ToolConfigurations] = TypeAdapter(ToolConfigurations)
|
||||
_TOOL_INPUT_MIXED_ADAPTER: TypeAdapter[str] = TypeAdapter(str)
|
||||
_TOOL_INPUT_CONSTANT_ADAPTER: TypeAdapter[ToolInputConstantValue] = TypeAdapter(ToolInputConstantValue)
|
||||
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
|
||||
|
||||
|
||||
class ToolEntity(BaseModel):
|
||||
provider_id: str
|
||||
@ -14,52 +26,41 @@ class ToolEntity(BaseModel):
|
||||
provider_name: str # redundancy
|
||||
tool_name: str
|
||||
tool_label: str # redundancy
|
||||
tool_configurations: dict[str, Any]
|
||||
tool_configurations: ToolConfigurations
|
||||
credential_id: str | None = None
|
||||
plugin_unique_identifier: str | None = None # redundancy
|
||||
|
||||
@field_validator("tool_configurations", mode="before")
|
||||
@classmethod
|
||||
def validate_tool_configurations(cls, value, values: ValidationInfo):
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError("tool_configurations must be a dictionary")
|
||||
|
||||
for key in values.data.get("tool_configurations", {}):
|
||||
value = values.data.get("tool_configurations", {}).get(key)
|
||||
if not isinstance(value, str | int | float | bool):
|
||||
raise ValueError(f"{key} must be a string")
|
||||
|
||||
return value
|
||||
def validate_tool_configurations(cls, value: object, _validation_info: ValidationInfo) -> ToolConfigurations:
|
||||
return _TOOL_CONFIGURATIONS_ADAPTER.validate_python(value)
|
||||
|
||||
|
||||
class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
type: NodeType = BuiltinNodeTypes.TOOL
|
||||
|
||||
class ToolInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
value: ToolInputConstantValue | VariableSelector
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def check_type(cls, value, validation_info: ValidationInfo):
|
||||
typ = value
|
||||
value = validation_info.data.get("value")
|
||||
def validate_value(
|
||||
cls, value: object, validation_info: ValidationInfo
|
||||
) -> ToolInputConstantValue | VariableSelector:
|
||||
input_type = validation_info.data.get("type")
|
||||
if input_type == "mixed":
|
||||
return _TOOL_INPUT_MIXED_ADAPTER.validate_python(value)
|
||||
if input_type == "variable":
|
||||
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
|
||||
if input_type == "constant":
|
||||
return _TOOL_INPUT_CONSTANT_ADAPTER.validate_python(value)
|
||||
raise ValueError(f"Unknown tool input type: {input_type}")
|
||||
|
||||
if value is None:
|
||||
return typ
|
||||
|
||||
if typ == "mixed" and not isinstance(value, str):
|
||||
raise ValueError("value must be a string")
|
||||
elif typ == "variable":
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("value must be a list")
|
||||
for val in value:
|
||||
if not isinstance(val, str):
|
||||
raise ValueError("value must be a list of strings")
|
||||
elif typ == "constant" and not isinstance(value, (allowed_types := (str, int, float, bool, dict, list))):
|
||||
raise ValueError(f"value must be one of: {', '.join(t.__name__ for t in allowed_types)}")
|
||||
return typ
|
||||
def require_variable_selector(self) -> VariableSelector:
|
||||
if self.type != "variable":
|
||||
raise ValueError(f"Expected variable tool input, got {self.type}")
|
||||
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
|
||||
|
||||
tool_parameters: dict[str, ToolInput]
|
||||
# The version of the tool parameter.
|
||||
@ -69,7 +70,7 @@ class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
|
||||
@field_validator("tool_parameters", mode="before")
|
||||
@classmethod
|
||||
def filter_none_tool_inputs(cls, value):
|
||||
def filter_none_tool_inputs(cls, value: object) -> object:
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
|
||||
@ -80,8 +81,10 @@ class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _has_valid_value(tool_input):
|
||||
def _has_valid_value(tool_input: object) -> bool:
|
||||
"""Check if the value is valid"""
|
||||
if isinstance(tool_input, dict):
|
||||
return tool_input.get("value") is not None
|
||||
return getattr(tool_input, "value", None) is not None
|
||||
if isinstance(tool_input, ToolNodeData.ToolInput):
|
||||
return tool_input.value is not None
|
||||
return False
|
||||
|
||||
@ -225,10 +225,11 @@ class ToolNode(Node[ToolNodeData]):
|
||||
continue
|
||||
tool_input = node_data.tool_parameters[parameter_name]
|
||||
if tool_input.type == "variable":
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
variable_selector = tool_input.require_variable_selector()
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
if parameter.required:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
raise ToolParameterError(f"Variable {variable_selector} does not exist")
|
||||
continue
|
||||
parameter_value = variable.value
|
||||
elif tool_input.type in {"mixed", "constant"}:
|
||||
@ -510,8 +511,9 @@ class ToolNode(Node[ToolNodeData]):
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
selector_key = ".".join(input.value)
|
||||
result[f"#{selector_key}#"] = input.value
|
||||
variable_selector = input.require_variable_selector()
|
||||
selector_key = ".".join(variable_selector)
|
||||
result[f"#{selector_key}#"] = variable_selector
|
||||
case "constant":
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user