fix: invoke tool streamingly

This commit is contained in:
Yeuoly
2024-08-30 18:11:38 +08:00
parent cf4e9f317e
commit 886a160115
16 changed files with 149 additions and 92 deletions

View File

@ -4,8 +4,8 @@ from typing import Optional, Union
from pydantic import BaseModel, ConfigDict, Field
from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType
from core.tools.entities.common_entities import I18nObject
from models.provider import ProviderQuotaType
@ -143,7 +143,7 @@ class ProviderConfig(BasicProviderConfig):
value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option")
scope: AppSelectorScope | ModelConfigScope | None
scope: AppSelectorScope | ModelConfigScope | None = None
required: bool = False
default: Optional[Union[int, str]] = None
options: Optional[list[Option]] = None

View File

@ -8,6 +8,7 @@ from extensions.ext_redis import redis_client
class ToolProviderCredentialsCacheType(Enum):
PROVIDER = "tool_provider"
ENDPOINT = "endpoint"
class ToolProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):

View File

@ -1,10 +1,11 @@
from typing import Literal, Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field
from core.entities.provider_entities import ProviderConfig
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ProviderConfig, ToolProviderType
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool.tool import ToolParameter
@ -14,7 +15,7 @@ class UserTool(BaseModel):
label: I18nObject # label
description: I18nObject
parameters: Optional[list[ToolParameter]] = None
labels: list[str] = None
labels: list[str] = Field(default_factory=list)
UserToolProviderTypeLiteral = Optional[Literal[
'builtin', 'api', 'workflow'
@ -32,8 +33,8 @@ class UserToolProvider(BaseModel):
original_credentials: Optional[dict] = None
is_team_authorization: bool = False
allow_delete: bool = True
tools: list[UserTool] = None
labels: list[str] = None
tools: list[UserTool] = Field(default_factory=list)
labels: list[str] = Field(default_factory=list)
def to_dict(self) -> dict:
# -------------

View File

@ -25,7 +25,7 @@ class ToolLabelEnum(Enum):
UTILITIES = 'utilities'
OTHER = 'other'
class ToolProviderType(Enum):
class ToolProviderType(str, Enum):
"""
Enum class for tool provider
"""
@ -181,7 +181,7 @@ class ToolParameter(BaseModel):
if options:
option_objs = [ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options]
else:
option_objs = None
option_objs = []
return cls(
name=name,
label=I18nObject(en_US='', zh_Hans=''),

View File

@ -1,21 +1,23 @@
from pydantic import Field
from core.entities.provider_entities import ProviderConfig
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ProviderConfig,
ToolCredentialsOption,
ToolProviderType,
)
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.api_tool import ApiTool
from core.tools.tool.tool import Tool
from extensions.ext_database import db
from models.tools import ApiToolProvider
class ApiToolProviderController(ToolProviderController):
provider_id: str
tenant_id: str
tools: list[ApiTool] = Field(default_factory=list)
@staticmethod
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController':
@ -25,8 +27,8 @@ class ApiToolProviderController(ToolProviderController):
required=True,
type=ProviderConfig.Type.SELECT,
options=[
ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='')),
ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
ProviderConfig.Option(value='none', label=I18nObject(en_US='None', zh_Hans='')),
ProviderConfig.Option(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
],
default='none',
help=I18nObject(
@ -67,9 +69,9 @@ class ApiToolProviderController(ToolProviderController):
zh_Hans='api key header 的前缀'
),
options=[
ToolCredentialsOption(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')),
ToolCredentialsOption(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')),
ToolCredentialsOption(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom'))
ProviderConfig.Option(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')),
ProviderConfig.Option(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')),
ProviderConfig.Option(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom'))
]
)
}
@ -96,6 +98,7 @@ class ApiToolProviderController(ToolProviderController):
},
'credentials_schema': credentials_schema,
'provider_id': db_provider.id or '',
'tenant_id': db_provider.tenant_id or '',
})
@property
@ -142,7 +145,7 @@ class ApiToolProviderController(ToolProviderController):
return self.tools
def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]:
def get_tools(self, tenant_id: str) -> list[ApiTool]:
"""
fetch tools from database
@ -153,7 +156,7 @@ class ApiToolProviderController(ToolProviderController):
if self.tools is not None:
return self.tools
tools: list[Tool] = []
tools: list[ApiTool] = []
# get tenant api providers
db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
@ -179,7 +182,7 @@ class ApiToolProviderController(ToolProviderController):
:return: the tool
"""
if self.tools is None:
self.get_tools()
self.get_tools(self.tenant_id)
for tool in self.tools:
if tool.identity.name == tool_name:

View File

@ -39,7 +39,7 @@ class BuiltinToolProviderController(ToolProviderController):
super().__init__(**{
'identity': provider_yaml['identity'],
'credentials_schema': provider_yaml.get('credentials_for_provider', None),
'credentials_schema': provider_yaml.get('credentials_for_provider', {}) or {},
})
def _get_builtin_tools(self) -> list[BuiltinTool]:

View File

@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
from core.entities.provider_entities import ProviderConfig
from core.tools.entities.tool_entities import (
@ -17,6 +17,8 @@ class ToolProviderController(BaseModel, ABC):
tools: list[Tool] = Field(default_factory=list)
credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)
model_config = ConfigDict(validate_assignment=True)
def get_credentials_schema(self) -> dict[str, ProviderConfig]:
"""
returns the credentials schema of the provider

View File

@ -206,7 +206,16 @@ class Tool(BaseModel, ABC):
tool_parameters=tool_parameters,
)
return result
if isinstance(result, ToolInvokeMessage):
def single_generator():
yield result
return single_generator()
elif isinstance(result, list):
def generator():
yield from result
return generator()
else:
return result
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
"""
@ -223,7 +232,7 @@ class Tool(BaseModel, ABC):
return result
@abstractmethod
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]:
pass
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:

View File

@ -116,7 +116,12 @@ class ToolManager:
# decrypt the credentials
credentials = builtin_provider.credentials
controller = cls.get_builtin_provider(provider_id)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
config=controller.get_credentials_schema(),
provider_type=controller.provider_type.value,
provider_identity=controller.identity.name
)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
@ -135,7 +140,12 @@ class ToolManager:
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
# decrypt the credentials
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
config=api_provider.get_credentials_schema(),
provider_type=api_provider.provider_type.value,
provider_identity=api_provider.identity.name
)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
return cast(ApiTool, api_provider.get_tool(tool_name).fork_tool_runtime(runtime={
@ -513,7 +523,12 @@ class ToolManager:
provider_obj, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
)
# init tool configuration
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
config=controller.get_credentials_schema(),
provider_type=controller.provider_type.value,
provider_identity=controller.identity.name
)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)

View File

@ -1,23 +1,25 @@
from collections.abc import Mapping
from copy import deepcopy
from typing import Any
from pydantic import BaseModel
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.entities.tool_entities import (
ProviderConfig,
ToolParameter,
ToolProviderType,
)
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.tool import Tool
class ToolConfigurationManager(BaseModel):
tenant_id: str
provider_controller: ToolProviderController
config: Mapping[str, BasicProviderConfig]
provider_type: str
provider_identity: str
def _deep_copy(self, credentials: dict[str, str]) -> dict[str, str]:
"""
@ -34,9 +36,9 @@ class ToolConfigurationManager(BaseModel):
credentials = self._deep_copy(credentials)
# get fields need to be decrypted
fields = self.provider_controller.get_credentials_schema()
fields = self.config
for field_name, field in fields.items():
if field.type == ProviderConfig.Type.SECRET_INPUT:
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in credentials:
encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
credentials[field_name] = encrypted
@ -52,9 +54,9 @@ class ToolConfigurationManager(BaseModel):
credentials = self._deep_copy(credentials)
# get fields need to be decrypted
fields = self.provider_controller.get_credentials_schema()
fields = self.config
for field_name, field in fields.items():
if field.type == ProviderConfig.Type.SECRET_INPUT:
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in credentials:
if len(credentials[field_name]) > 6:
credentials[field_name] = \
@ -74,7 +76,7 @@ class ToolConfigurationManager(BaseModel):
"""
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
identity_id=f'{self.provider_type}.{self.provider_identity}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER
)
cached_credentials = cache.get()
@ -82,9 +84,9 @@ class ToolConfigurationManager(BaseModel):
return cached_credentials
credentials = self._deep_copy(credentials)
# get fields need to be decrypted
fields = self.provider_controller.get_credentials_schema()
fields = self.config
for field_name, field in fields.items():
if field.type == ProviderConfig.Type.SECRET_INPUT:
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in credentials:
try:
credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
@ -97,7 +99,7 @@ class ToolConfigurationManager(BaseModel):
def delete_tool_credentials_cache(self):
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
identity_id=f'{self.provider_type}.{self.provider_identity}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER
)
cache.delete()

View File

@ -16,7 +16,7 @@ from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolPro
class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict | None = None, warning: dict | None = None) -> list[ApiToolBundle]:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
@ -173,7 +173,7 @@ class ApiBasedToolSchemaParser:
return ToolParameter.ToolParameterType.STRING
@staticmethod
def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict | None = None, warning: dict | None = None) -> list[ApiToolBundle]:
"""
parse openapi yaml to tool bundle
@ -189,7 +189,8 @@ class ApiBasedToolSchemaParser:
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
@staticmethod
def parse_swagger_to_openapi(swagger: dict, extra_info: dict = None, warning: dict = None) -> dict:
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
warning = warning or {}
"""
parse swagger to openapi
@ -255,7 +256,7 @@ class ApiBasedToolSchemaParser:
return openapi
@staticmethod
def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict | None = None, warning: dict | None = None) -> list[ApiToolBundle]:
"""
parse openapi plugin yaml to tool bundle
@ -287,7 +288,7 @@ class ApiBasedToolSchemaParser:
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning)
@staticmethod
def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiToolBundle], str]:
def auto_parse_to_tool_bundle(content: str, extra_info: dict | None = None, warning: dict | None = None) -> tuple[list[ApiToolBundle], str]:
"""
auto parse to tool bundle

View File

@ -1,6 +1,6 @@
from collections.abc import Generator, Sequence
from os import path
from typing import Any, cast
from typing import Any, Iterable, cast
from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
@ -158,14 +158,17 @@ class ToolNode(BaseNode):
tenant_id=self.tenant_id,
conversation_id=None,
)
result = list(messages)
# extract plain text and files
files = self._extract_tool_response_binary(messages)
plain_text = self._extract_tool_response_text(messages)
json = self._extract_tool_response_json(messages)
files = self._extract_tool_response_binary(result)
plain_text = self._extract_tool_response_text(result)
json = self._extract_tool_response_json(result)
return plain_text, files, json
def _extract_tool_response_binary(self, tool_response: Generator[ToolInvokeMessage, None, None]) -> list[FileVar]:
def _extract_tool_response_binary(self, tool_response: Iterable[ToolInvokeMessage]) -> list[FileVar]:
"""
Extract tool response binary
"""
@ -215,7 +218,7 @@ class ToolNode(BaseNode):
return result
def _extract_tool_response_text(self, tool_response: Generator[ToolInvokeMessage]) -> str:
def _extract_tool_response_text(self, tool_response: Iterable[ToolInvokeMessage]) -> str:
"""
Extract tool response text
"""
@ -230,7 +233,7 @@ class ToolNode(BaseNode):
return '\n'.join(result)
def _extract_tool_response_json(self, tool_response: Generator[ToolInvokeMessage]) -> list[dict]:
def _extract_tool_response_json(self, tool_response: Iterable[ToolInvokeMessage]) -> list[dict]:
result: list[dict] = []
for message in tool_response:
if message.type == ToolInvokeMessage.MessageType.JSON: