refactor: tool

This commit is contained in:
Yeuoly
2024-09-20 23:48:48 +08:00
parent 3c1d32e3ac
commit 91cb80f795
29 changed files with 498 additions and 906 deletions

View File

@ -211,7 +211,9 @@ class WorkflowToolManageService:
ToolTransformService.repack_provider(user_tool_provider)
user_tool_provider.tools = [
ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, [])
tool=tool.get_tools(user_id, tenant_id)[0],
labels=labels.get(tool.provider_id, []),
tenant_id=tenant_id,
)
]
result.append(user_tool_provider)
@ -248,7 +250,7 @@ class WorkflowToolManageService:
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
)
return cls._get_workflow_tool(db_tool)
return cls._get_workflow_tool(tenant_id, db_tool)
@classmethod
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
@ -264,10 +266,10 @@ class WorkflowToolManageService:
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.first()
)
return cls._get_workflow_tool(db_tool)
return cls._get_workflow_tool(tenant_id, db_tool)
@classmethod
def _get_workflow_tool(cls, db_tool: WorkflowToolProvider | None):
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None):
"""
Get a workflow tool.
:db_tool: the database tool
@ -298,7 +300,9 @@ class WorkflowToolManageService:
"description": db_tool.description,
"parameters": jsonable_encoder(db_tool.parameter_configurations),
"tool": ToolTransformService.tool_to_user_tool(
tool.get_tools(db_tool.tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
tool=tool.get_tools(db_tool.tenant_id)[0],
labels=ToolLabelManager.get_tool_labels(tool),
tenant_id=tenant_id,
),
"synced": workflow.version == db_tool.version,
"privacy_policy": db_tool.privacy_policy,
@ -326,6 +330,8 @@ class WorkflowToolManageService:
return [
ToolTransformService.tool_to_user_tool(
tool=tool.get_tools(db_tool.tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
tool=tool.get_tools(db_tool.tenant_id)[0],
labels=ToolLabelManager.get_tool_labels(tool),
tenant_id=tenant_id,
)
]