mirror of
https://github.com/langgenius/dify.git
synced 2026-04-27 05:58:14 +08:00
Merge remote-tracking branch 'origin/feat/workflow' into feat/workflow
This commit is contained in:
@ -11,6 +11,7 @@ class ToolProviderType(Enum):
|
||||
Enum class for tool provider
|
||||
"""
|
||||
BUILT_IN = "built-in"
|
||||
DATASET_RETRIEVAL = "dataset-retrieval"
|
||||
APP_BASED = "app-based"
|
||||
API_BASED = "api-based"
|
||||
|
||||
@ -161,6 +162,8 @@ class ToolIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the tool")
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
provider: str = Field(..., description="The provider of the tool")
|
||||
icon: Optional[str] = None
|
||||
|
||||
class ToolCredentialsOption(BaseModel):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
@ -334,23 +337,25 @@ class ToolInvokeMeta(BaseModel):
|
||||
"""
|
||||
time_cost: float = Field(..., description="The time cost of the tool invoke")
|
||||
error: Optional[str] = None
|
||||
tool_config: Optional[dict] = None
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> 'ToolInvokeMeta':
|
||||
"""
|
||||
Get an empty instance of ToolInvokeMeta
|
||||
"""
|
||||
return cls(time_cost=0.0, error=None)
|
||||
return cls(time_cost=0.0, error=None, tool_config={})
|
||||
|
||||
@classmethod
|
||||
def error_instance(cls, error: str) -> 'ToolInvokeMeta':
|
||||
"""
|
||||
Get an instance of ToolInvokeMeta with error
|
||||
"""
|
||||
return cls(time_cost=0.0, error=error)
|
||||
return cls(time_cost=0.0, error=error, tool_config={})
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
'time_cost': self.time_cost,
|
||||
'error': self.error,
|
||||
'tool_config': self.tool_config,
|
||||
}
|
||||
@ -16,6 +16,8 @@ from models.tools import ApiToolProvider
|
||||
|
||||
|
||||
class ApiBasedToolProviderController(ToolProviderController):
|
||||
provider_id: str
|
||||
|
||||
@staticmethod
|
||||
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiBasedToolProviderController':
|
||||
credentials_schema = {
|
||||
@ -89,9 +91,10 @@ class ApiBasedToolProviderController(ToolProviderController):
|
||||
'en_US': db_provider.description,
|
||||
'zh_Hans': db_provider.description
|
||||
},
|
||||
'icon': db_provider.icon
|
||||
'icon': db_provider.icon,
|
||||
},
|
||||
'credentials_schema': credentials_schema
|
||||
'credentials_schema': credentials_schema,
|
||||
'provider_id': db_provider.id,
|
||||
})
|
||||
|
||||
@property
|
||||
@ -120,7 +123,8 @@ class ApiBasedToolProviderController(ToolProviderController):
|
||||
'en_US': tool_bundle.operation_id,
|
||||
'zh_Hans': tool_bundle.operation_id
|
||||
},
|
||||
'icon': tool_bundle.icon if tool_bundle.icon else ''
|
||||
'icon': self.identity.icon,
|
||||
'provider': self.provider_id,
|
||||
},
|
||||
'description': {
|
||||
'human': {
|
||||
|
||||
@ -68,6 +68,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
script_path=path.join(path.dirname(path.realpath(__file__)),
|
||||
'builtin', provider, 'tools', f'{tool_name}.py'),
|
||||
parent_type=BuiltinTool)
|
||||
tool["identity"]["provider"] = provider
|
||||
tools.append(assistant_tool_class(**tool))
|
||||
|
||||
self.tools = tools
|
||||
|
||||
@ -8,7 +8,8 @@ import requests
|
||||
|
||||
import core.helper.ssrf_proxy as ssrf_proxy
|
||||
from core.tools.entities.tool_bundle import ApiBasedToolBundle
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
@ -34,7 +35,7 @@ class ApiTool(Tool):
|
||||
api_bundle=self.api_bundle.copy() if self.api_bundle else None,
|
||||
runtime=Tool.Runtime(**meta)
|
||||
)
|
||||
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> str:
|
||||
"""
|
||||
validate the credentials for Api tool
|
||||
@ -49,6 +50,9 @@ class ApiTool(Tool):
|
||||
# validate response
|
||||
return self.validate_and_parse_response(response)
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return UserToolProvider.ProviderType.API
|
||||
|
||||
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
headers = {}
|
||||
credentials = self.runtime.credentials or {}
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from core.tools.model.tool_model_manager import ToolModelManager
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.utils.web_reader_tool import get_url
|
||||
@ -40,6 +42,9 @@ class BuiltinTool(Tool):
|
||||
prompt_messages=prompt_messages,
|
||||
)
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return UserToolProvider.ProviderType.BUILTIN
|
||||
|
||||
def get_max_tokens(self) -> int:
|
||||
"""
|
||||
get max tokens
|
||||
|
||||
@ -7,7 +7,13 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolIdentity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
|
||||
@ -53,7 +59,7 @@ class DatasetRetrieverTool(Tool):
|
||||
for langchain_tool in langchain_tools:
|
||||
tool = DatasetRetrieverTool(
|
||||
langchain_tool=langchain_tool,
|
||||
identity=ToolIdentity(author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')),
|
||||
identity=ToolIdentity(provider='', author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')),
|
||||
parameters=[],
|
||||
is_team_authorization=True,
|
||||
description=ToolDescription(
|
||||
@ -77,6 +83,9 @@ class DatasetRetrieverTool(Tool):
|
||||
required=True,
|
||||
default=''),
|
||||
]
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.DATASET_RETRIEVAL
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
|
||||
@ -11,7 +11,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.tools.entities.tool_entities import ModelToolPropertyKey, ToolInvokeMessage
|
||||
from core.tools.entities.tool_entities import ModelToolPropertyKey, ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
VISION_PROMPT = """## Image Recognition Task
|
||||
@ -79,6 +79,9 @@ class ModelTool(Tool):
|
||||
"""
|
||||
pass
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
"""
|
||||
|
||||
@ -9,6 +9,7 @@ from core.tools.entities.tool_entities import (
|
||||
ToolIdentity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
ToolRuntimeImageVariable,
|
||||
ToolRuntimeVariable,
|
||||
ToolRuntimeVariablePool,
|
||||
@ -59,6 +60,14 @@ class Tool(BaseModel, ABC):
|
||||
runtime=Tool.Runtime(**meta),
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
get the tool provider type
|
||||
|
||||
:return: the tool provider type
|
||||
"""
|
||||
|
||||
def load_variables(self, variables: ToolRuntimeVariablePool):
|
||||
"""
|
||||
load variables from database
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from copy import deepcopy
|
||||
from datetime import datetime, timezone
|
||||
from typing import Union
|
||||
|
||||
@ -55,11 +56,7 @@ class ToolEngine:
|
||||
tool_inputs=tool_parameters
|
||||
)
|
||||
|
||||
try:
|
||||
meta, response = ToolEngine._invoke(tool, tool_parameters, user_id)
|
||||
except ToolEngineInvokeError as e:
|
||||
meta = e.meta
|
||||
|
||||
meta, response = ToolEngine._invoke(tool, tool_parameters, user_id)
|
||||
response = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=response,
|
||||
user_id=user_id,
|
||||
@ -104,11 +101,16 @@ class ToolEngine:
|
||||
except ToolInvokeError as e:
|
||||
error_response = f"tool invoke error: {e}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except ToolEngineInvokeError as e:
|
||||
meta = e.args[0]
|
||||
error_response = f"tool invoke error: {meta.error}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
return error_response, [], meta
|
||||
except Exception as e:
|
||||
error_response = f"unknown error: {e}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
|
||||
return error_response, [], meta
|
||||
return error_response, [], ToolInvokeMeta.error_instance(error_response)
|
||||
|
||||
@staticmethod
|
||||
def workflow_invoke(tool: Tool, tool_parameters: dict,
|
||||
@ -146,12 +148,18 @@ class ToolEngine:
|
||||
Invoke the tool with the given arguments.
|
||||
"""
|
||||
started_at = datetime.now(timezone.utc)
|
||||
meta = ToolInvokeMeta(time_cost=0.0, error=None)
|
||||
meta = ToolInvokeMeta(time_cost=0.0, error=None, tool_config={
|
||||
'tool_name': tool.identity.name,
|
||||
'tool_provider': tool.identity.provider,
|
||||
'tool_provider_type': tool.tool_provider_type().value,
|
||||
'tool_parameters': deepcopy(tool.runtime.runtime_parameters),
|
||||
'tool_icon': tool.identity.icon
|
||||
})
|
||||
try:
|
||||
response = tool.invoke(user_id, tool_parameters)
|
||||
except Exception as e:
|
||||
meta.error = str(e)
|
||||
raise ToolEngineInvokeError(meta=meta)
|
||||
raise ToolEngineInvokeError(meta)
|
||||
finally:
|
||||
ended_at = datetime.now(timezone.utc)
|
||||
meta.time_cost = (ended_at - started_at).total_seconds()
|
||||
|
||||
Reference in New Issue
Block a user