refactor: tool entities

This commit is contained in:
Yeuoly
2024-12-13 19:50:54 +08:00
parent 63206a7967
commit 65a4cb769b
17 changed files with 329 additions and 356 deletions

View File

@ -285,28 +285,6 @@ class ToolManager:
else:
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
@classmethod
def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict):
"""
init runtime parameter
"""
parameter_value = parameters.get(parameter_rule.name)
if not parameter_value and parameter_value != 0:
# get default value
parameter_value = parameter_rule.default
if not parameter_value and parameter_rule.required:
raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")
if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
# check if tool_parameter_config in options
options = [x.value for x in parameter_rule.options]
if parameter_value is not None and parameter_value not in options:
raise ValueError(
f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
)
return parameter_rule.type.cast_value(parameter_value)
@classmethod
def get_agent_tool_runtime(
cls,
@ -343,7 +321,7 @@ class ToolManager:
if parameter.form == ToolParameter.ToolParameterForm.FORM:
# save tool parameter to tool entity memory
value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
value = parameter.init_frontend_parameter(agent_tool.tool_parameters.get(parameter.name))
runtime_parameters[parameter.name] = value
# decrypt runtime parameters
@ -356,9 +334,6 @@ class ToolManager:
)
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
if not tool_entity.runtime:
raise Exception("tool missing runtime")
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity
@ -388,7 +363,7 @@ class ToolManager:
for parameter in parameters:
# save tool parameter to tool entity memory
if parameter.form == ToolParameter.ToolParameterForm.FORM:
value = cls._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
value = parameter.init_frontend_parameter(workflow_tool.tool_configurations.get(parameter.name))
runtime_parameters[parameter.name] = value
# decrypt runtime parameters
@ -403,9 +378,6 @@ class ToolManager:
if runtime_parameters:
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
if not tool_runtime.runtime:
raise Exception("tool missing runtime")
tool_runtime.runtime.runtime_parameters.update(runtime_parameters)
return tool_runtime
@ -434,12 +406,9 @@ class ToolManager:
for parameter in parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM:
# save tool parameter to tool entity memory
value = cls._init_runtime_parameter(parameter, tool_parameters)
value = parameter.init_frontend_parameter(tool_parameters.get(parameter.name))
runtime_parameters[parameter.name] = value
if not tool_entity.runtime:
raise Exception("tool missing runtime")
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity
@ -608,9 +577,8 @@ class ToolManager:
tool_provider_id = GenericProviderID(db_provider.provider)
db_provider.provider = tool_provider_id.to_string()
find_db_builtin_provider = lambda provider: next(
(x for x in db_builtin_providers if x.provider == provider), None
)
def find_db_builtin_provider(provider):
return next((x for x in db_builtin_providers if x.provider == provider), None)
# append builtin providers
for provider in builtin_providers: