Merge remote-tracking branch 'origin/feat/workflow' into feat/workflow

This commit is contained in:
jyong
2024-03-29 19:29:42 +08:00
21 changed files with 601 additions and 603 deletions

View File

@ -11,6 +11,7 @@ class ToolProviderType(Enum):
Enum class for tool provider
"""
BUILT_IN = "built-in"
DATASET_RETRIEVAL = "dataset-retrieval"
APP_BASED = "app-based"
API_BASED = "api-based"
@ -161,6 +162,8 @@ class ToolIdentity(BaseModel):
author: str = Field(..., description="The author of the tool")
name: str = Field(..., description="The name of the tool")
label: I18nObject = Field(..., description="The label of the tool")
provider: str = Field(..., description="The provider of the tool")
icon: Optional[str] = None
class ToolCredentialsOption(BaseModel):
value: str = Field(..., description="The value of the option")
@ -334,23 +337,25 @@ class ToolInvokeMeta(BaseModel):
"""
time_cost: float = Field(..., description="The time cost of the tool invoke")
error: Optional[str] = None
tool_config: Optional[dict] = None
@classmethod
def empty(cls) -> 'ToolInvokeMeta':
"""
Get an empty instance of ToolInvokeMeta
"""
return cls(time_cost=0.0, error=None)
return cls(time_cost=0.0, error=None, tool_config={})
@classmethod
def error_instance(cls, error: str) -> 'ToolInvokeMeta':
"""
Get an instance of ToolInvokeMeta with error
"""
return cls(time_cost=0.0, error=error)
return cls(time_cost=0.0, error=error, tool_config={})
def to_dict(self) -> dict:
return {
'time_cost': self.time_cost,
'error': self.error,
'tool_config': self.tool_config,
}

View File

@ -16,6 +16,8 @@ from models.tools import ApiToolProvider
class ApiBasedToolProviderController(ToolProviderController):
provider_id: str
@staticmethod
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiBasedToolProviderController':
credentials_schema = {
@ -89,9 +91,10 @@ class ApiBasedToolProviderController(ToolProviderController):
'en_US': db_provider.description,
'zh_Hans': db_provider.description
},
'icon': db_provider.icon
'icon': db_provider.icon,
},
'credentials_schema': credentials_schema
'credentials_schema': credentials_schema,
'provider_id': db_provider.id,
})
@property
@ -120,7 +123,8 @@ class ApiBasedToolProviderController(ToolProviderController):
'en_US': tool_bundle.operation_id,
'zh_Hans': tool_bundle.operation_id
},
'icon': tool_bundle.icon if tool_bundle.icon else ''
'icon': self.identity.icon,
'provider': self.provider_id,
},
'description': {
'human': {

View File

@ -68,6 +68,7 @@ class BuiltinToolProviderController(ToolProviderController):
script_path=path.join(path.dirname(path.realpath(__file__)),
'builtin', provider, 'tools', f'{tool_name}.py'),
parent_type=BuiltinTool)
tool["identity"]["provider"] = provider
tools.append(assistant_tool_class(**tool))
self.tools = tools

View File

@ -8,7 +8,8 @@ import requests
import core.helper.ssrf_proxy as ssrf_proxy
from core.tools.entities.tool_bundle import ApiBasedToolBundle
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
from core.tools.entities.user_entities import UserToolProvider
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
from core.tools.tool.tool import Tool
@ -34,7 +35,7 @@ class ApiTool(Tool):
api_bundle=self.api_bundle.copy() if self.api_bundle else None,
runtime=Tool.Runtime(**meta)
)
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> str:
"""
validate the credentials for Api tool
@ -49,6 +50,9 @@ class ApiTool(Tool):
# validate response
return self.validate_and_parse_response(response)
def tool_provider_type(self) -> ToolProviderType:
return UserToolProvider.ProviderType.API
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
headers = {}
credentials = self.runtime.credentials or {}

View File

@ -1,6 +1,8 @@
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.entities.user_entities import UserToolProvider
from core.tools.model.tool_model_manager import ToolModelManager
from core.tools.tool.tool import Tool
from core.tools.utils.web_reader_tool import get_url
@ -40,6 +42,9 @@ class BuiltinTool(Tool):
prompt_messages=prompt_messages,
)
def tool_provider_type(self) -> ToolProviderType:
return UserToolProvider.ProviderType.BUILTIN
def get_max_tokens(self) -> int:
"""
get max tokens

View File

@ -7,7 +7,13 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
from core.tools.entities.tool_entities import (
ToolDescription,
ToolIdentity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
from core.tools.tool.tool import Tool
@ -53,7 +59,7 @@ class DatasetRetrieverTool(Tool):
for langchain_tool in langchain_tools:
tool = DatasetRetrieverTool(
langchain_tool=langchain_tool,
identity=ToolIdentity(author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')),
identity=ToolIdentity(provider='', author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')),
parameters=[],
is_team_authorization=True,
description=ToolDescription(
@ -77,6 +83,9 @@ class DatasetRetrieverTool(Tool):
required=True,
default=''),
]
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.DATASET_RETRIEVAL
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
"""

View File

@ -11,7 +11,7 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.entities.tool_entities import ModelToolPropertyKey, ToolInvokeMessage
from core.tools.entities.tool_entities import ModelToolPropertyKey, ToolInvokeMessage, ToolProviderType
from core.tools.tool.tool import Tool
VISION_PROMPT = """## Image Recognition Task
@ -79,6 +79,9 @@ class ModelTool(Tool):
"""
pass
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.BUILT_IN
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
"""
"""

View File

@ -9,6 +9,7 @@ from core.tools.entities.tool_entities import (
ToolIdentity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
ToolRuntimeImageVariable,
ToolRuntimeVariable,
ToolRuntimeVariablePool,
@ -59,6 +60,14 @@ class Tool(BaseModel, ABC):
runtime=Tool.Runtime(**meta),
)
@abstractmethod
def tool_provider_type(self) -> ToolProviderType:
"""
get the tool provider type
:return: the tool provider type
"""
def load_variables(self, variables: ToolRuntimeVariablePool):
"""
load variables from database

View File

@ -1,3 +1,4 @@
from copy import deepcopy
from datetime import datetime, timezone
from typing import Union
@ -55,11 +56,7 @@ class ToolEngine:
tool_inputs=tool_parameters
)
try:
meta, response = ToolEngine._invoke(tool, tool_parameters, user_id)
except ToolEngineInvokeError as e:
meta = e.meta
meta, response = ToolEngine._invoke(tool, tool_parameters, user_id)
response = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=response,
user_id=user_id,
@ -104,11 +101,16 @@ class ToolEngine:
except ToolInvokeError as e:
error_response = f"tool invoke error: {e}"
agent_tool_callback.on_tool_error(e)
except ToolEngineInvokeError as e:
meta = e.args[0]
error_response = f"tool invoke error: {meta.error}"
agent_tool_callback.on_tool_error(e)
return error_response, [], meta
except Exception as e:
error_response = f"unknown error: {e}"
agent_tool_callback.on_tool_error(e)
return error_response, [], meta
return error_response, [], ToolInvokeMeta.error_instance(error_response)
@staticmethod
def workflow_invoke(tool: Tool, tool_parameters: dict,
@ -146,12 +148,18 @@ class ToolEngine:
Invoke the tool with the given arguments.
"""
started_at = datetime.now(timezone.utc)
meta = ToolInvokeMeta(time_cost=0.0, error=None)
meta = ToolInvokeMeta(time_cost=0.0, error=None, tool_config={
'tool_name': tool.identity.name,
'tool_provider': tool.identity.provider,
'tool_provider_type': tool.tool_provider_type().value,
'tool_parameters': deepcopy(tool.runtime.runtime_parameters),
'tool_icon': tool.identity.icon
})
try:
response = tool.invoke(user_id, tool_parameters)
except Exception as e:
meta.error = str(e)
raise ToolEngineInvokeError(meta=meta)
raise ToolEngineInvokeError(meta)
finally:
ended_at = datetime.now(timezone.utc)
meta.time_cost = (ended_at - started_at).total_seconds()