refactor: tool entities

This commit is contained in:
Yeuoly
2024-12-13 19:50:54 +08:00
parent 63206a7967
commit 65a4cb769b
17 changed files with 329 additions and 356 deletions

View File

@ -117,7 +117,7 @@ class AgentNode(ToolNode):
continue
agent_input = node_data.agent_parameters[parameter_name]
if agent_input.type == "variable":
variable = variable_pool.get(agent_input.value)
variable = variable_pool.get(agent_input.value) # type: ignore
if variable is None:
raise ValueError(f"Variable {agent_input.value} does not exist")
parameter_value = variable.value

View File

@ -2,6 +2,7 @@ from typing import Any, Literal, Union
from pydantic import BaseModel, ValidationInfo, field_validator
from core.tools.entities.tool_entities import ToolSelector
from core.workflow.nodes.base.entities import BaseNodeData
@ -20,8 +21,21 @@ class AgentEntity(BaseModel):
for key in values.data.get("agent_configurations", {}):
value = values.data.get("agent_configurations", {}).get(key)
if not isinstance(value, str | int | float | bool):
raise ValueError(f"{key} must be a string")
if isinstance(value, dict):
# convert dict to ToolSelector
return ToolSelector(**value)
elif isinstance(value, ToolSelector):
return value
elif isinstance(value, list):
# convert list[ToolSelector] to ToolSelector
if all(isinstance(val, dict) for val in value):
return [ToolSelector(**val) for val in value]
elif all(isinstance(val, ToolSelector) for val in value):
return value
else:
raise ValueError("value must be a list of ToolSelector")
else:
raise ValueError("value must be a dictionary or ToolSelector")
return value
@ -29,7 +43,7 @@ class AgentEntity(BaseModel):
class AgentNodeData(BaseNodeData, AgentEntity):
class AgentInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]
value: Union[list[str], list[ToolSelector], Any]
type: Literal["mixed", "variable", "constant"]
@field_validator("type", mode="before")
@ -45,8 +59,23 @@ class AgentNodeData(BaseNodeData, AgentEntity):
for val in value:
if not isinstance(val, str):
raise ValueError("value must be a list of strings")
elif typ == "constant" and not isinstance(value, str | int | float | bool):
raise ValueError("value must be a string, int, float, or bool")
elif typ == "constant":
if isinstance(value, list):
# convert dict to ToolSelector
if all(isinstance(val, dict) for val in value):
return value
elif all(isinstance(val, ToolSelector) for val in value):
return value
else:
raise ValueError("value must be a list of ToolSelector")
elif isinstance(value, dict):
# convert dict to ToolSelector
return ToolSelector(**value)
elif isinstance(value, ToolSelector):
return value
else:
raise ValueError("value must be a list of ToolSelector")
return typ
agent_parameters: dict[str, AgentInput]