refactor: tool response to generator

This commit is contained in:
Yeuoly
2024-07-09 15:37:56 +08:00
parent 364df36ac4
commit 563d81277b
15 changed files with 177 additions and 110 deletions

View File

@ -1,5 +1,6 @@
import json
import logging
from collections.abc import Generator
from typing import Optional, Union
from flask import current_app
@ -9,6 +10,7 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolInvokeMessage,
ToolParameter,
ToolProviderCredentials,
ToolProviderType,
@ -24,8 +26,8 @@ from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvi
logger = logging.getLogger(__name__)
class ToolTransformService:
@staticmethod
def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str) -> Union[str, dict]:
@classmethod
def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str) -> Union[str, dict]:
"""
get tool provider icon url
"""
@ -45,8 +47,8 @@ class ToolTransformService:
return ''
@staticmethod
def repack_provider(provider: Union[dict, UserToolProvider]):
@classmethod
def repack_provider(cls, provider: Union[dict, UserToolProvider]):
"""
repack provider
@ -65,8 +67,9 @@ class ToolTransformService:
icon=provider.icon
)
@staticmethod
@classmethod
def builtin_provider_to_user_provider(
cls,
provider_controller: BuiltinToolProviderController,
db_provider: Optional[BuiltinToolProvider],
decrypt_credentials: bool = True,
@ -126,8 +129,9 @@ class ToolTransformService:
return result
@staticmethod
@classmethod
def api_provider_to_controller(
cls,
db_provider: ApiToolProvider,
) -> ApiToolProviderController:
"""
@ -142,8 +146,9 @@ class ToolTransformService:
return controller
@staticmethod
@classmethod
def workflow_provider_to_controller(
cls,
db_provider: WorkflowToolProvider
) -> WorkflowToolProviderController:
"""
@ -179,8 +184,9 @@ class ToolTransformService:
labels=labels or []
)
@staticmethod
@classmethod
def api_provider_to_user_provider(
cls,
provider_controller: ApiToolProviderController,
db_provider: ApiToolProvider,
decrypt_credentials: bool = True,
@ -231,8 +237,9 @@ class ToolTransformService:
return result
@staticmethod
@classmethod
def tool_to_user_tool(
cls,
tool: Union[ApiToolBundle, WorkflowTool, Tool],
credentials: dict = None,
tenant_id: str = None,
@ -287,4 +294,9 @@ class ToolTransformService:
),
parameters=tool.parameters,
labels=labels
)
)
@classmethod
def transform_messages_to_dict(cls, responses: Generator[ToolInvokeMessage, None, None]):
for response in responses:
yield response.model_dump()