mirror of
https://github.com/langgenius/dify.git
synced 2026-04-27 22:18:15 +08:00
r2
This commit is contained in:
@ -1,221 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.model import File
|
||||
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceProviderType
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
)
|
||||
|
||||
|
||||
class Datasource(ABC):
|
||||
"""
|
||||
The base class of a datasource
|
||||
"""
|
||||
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
|
||||
def __init__(self, entity: DatasourceEntity, runtime: DatasourceRuntime) -> None:
|
||||
self.entity = entity
|
||||
self.runtime = runtime
|
||||
|
||||
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "Datasource":
|
||||
"""
|
||||
fork a new datasource with metadata
|
||||
:return: the new datasource
|
||||
"""
|
||||
return self.__class__(
|
||||
entity=self.entity.model_copy(),
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def datasource_provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
get the datasource provider type
|
||||
|
||||
:return: the tool provider type
|
||||
"""
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage]:
|
||||
if self.runtime and self.runtime.runtime_parameters:
|
||||
tool_parameters.update(self.runtime.runtime_parameters)
|
||||
|
||||
# try parse tool parameters into the correct type
|
||||
tool_parameters = self._transform_tool_parameters_type(tool_parameters)
|
||||
|
||||
result = self._invoke(
|
||||
user_id=user_id,
|
||||
tool_parameters=tool_parameters,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
if isinstance(result, ToolInvokeMessage):
|
||||
|
||||
def single_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield result
|
||||
|
||||
return single_generator()
|
||||
elif isinstance(result, list):
|
||||
|
||||
def generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield from result
|
||||
|
||||
return generator()
|
||||
else:
|
||||
return result
|
||||
|
||||
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Transform tool parameters type
|
||||
"""
|
||||
# Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials
|
||||
result = deepcopy(tool_parameters)
|
||||
for parameter in self.entity.parameters or []:
|
||||
if parameter.name in tool_parameters:
|
||||
result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
|
||||
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]:
|
||||
pass
|
||||
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> list[ToolParameter]:
|
||||
"""
|
||||
get the runtime parameters
|
||||
|
||||
interface for developer to dynamic change the parameters of a tool depends on the variables pool
|
||||
|
||||
:return: the runtime parameters
|
||||
"""
|
||||
return self.entity.parameters
|
||||
|
||||
def get_merged_runtime_parameters(
|
||||
self,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> list[ToolParameter]:
|
||||
"""
|
||||
get merged runtime parameters
|
||||
|
||||
:return: merged runtime parameters
|
||||
"""
|
||||
parameters = self.entity.parameters
|
||||
parameters = parameters.copy()
|
||||
user_parameters = self.get_runtime_parameters() or []
|
||||
user_parameters = user_parameters.copy()
|
||||
|
||||
# override parameters
|
||||
for parameter in user_parameters:
|
||||
# check if parameter in tool parameters
|
||||
for tool_parameter in parameters:
|
||||
if tool_parameter.name == parameter.name:
|
||||
# override parameter
|
||||
tool_parameter.type = parameter.type
|
||||
tool_parameter.form = parameter.form
|
||||
tool_parameter.required = parameter.required
|
||||
tool_parameter.default = parameter.default
|
||||
tool_parameter.options = parameter.options
|
||||
tool_parameter.llm_description = parameter.llm_description
|
||||
break
|
||||
else:
|
||||
# add new parameter
|
||||
parameters.append(parameter)
|
||||
|
||||
return parameters
|
||||
|
||||
def create_image_message(
|
||||
self,
|
||||
image: str,
|
||||
) -> ToolInvokeMessage:
|
||||
"""
|
||||
create an image message
|
||||
|
||||
:param image: the url of the image
|
||||
:return: the image message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
|
||||
)
|
||||
|
||||
def create_file_message(self, file: "File") -> ToolInvokeMessage:
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.FILE,
|
||||
message=ToolInvokeMessage.FileMessage(),
|
||||
meta={"file": file},
|
||||
)
|
||||
|
||||
def create_link_message(self, link: str) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a link message
|
||||
|
||||
:param link: the url of the link
|
||||
:return: the link message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link)
|
||||
)
|
||||
|
||||
def create_text_message(self, text: str) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a text message
|
||||
|
||||
:param text: the text
|
||||
:return: the text message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.TEXT,
|
||||
message=ToolInvokeMessage.TextMessage(text=text),
|
||||
)
|
||||
|
||||
def create_blob_message(self, blob: bytes, meta: Optional[dict] = None) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a blob message
|
||||
|
||||
:param blob: the blob
|
||||
:param meta: the meta info of blob object
|
||||
:return: the blob message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB,
|
||||
message=ToolInvokeMessage.BlobMessage(blob=blob),
|
||||
meta=meta,
|
||||
)
|
||||
|
||||
def create_json_message(self, object: dict) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a json message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object)
|
||||
)
|
||||
109
api/core/datasource/__base/datasource_plugin.py
Normal file
109
api/core/datasource/__base/datasource_plugin.py
Normal file
@ -0,0 +1,109 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceInvokeMessage,
|
||||
DatasourceParameter,
|
||||
DatasourceProviderType,
|
||||
)
|
||||
from core.plugin.manager.datasource import PluginDatasourceManager
|
||||
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
||||
|
||||
|
||||
class DatasourcePlugin:
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
runtime_parameters: Optional[list[DatasourceParameter]]
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
|
||||
def __init__(
|
||||
self, entity: DatasourceEntity, runtime: DatasourceRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str
|
||||
) -> None:
|
||||
self.entity = entity
|
||||
self.runtime = runtime
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
self.runtime_parameters = None
|
||||
|
||||
def datasource_provider_type(self) -> DatasourceProviderType:
|
||||
return DatasourceProviderType.RAG_PIPELINE
|
||||
|
||||
def _invoke_first_step(
|
||||
self,
|
||||
user_id: str,
|
||||
datasource_parameters: dict[str, Any],
|
||||
rag_pipeline_id: Optional[str] = None,
|
||||
) -> Generator[DatasourceInvokeMessage, None, None]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
|
||||
|
||||
yield from manager.invoke_first_step(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
datasource_parameters=datasource_parameters,
|
||||
rag_pipeline_id=rag_pipeline_id,
|
||||
)
|
||||
|
||||
def _invoke_second_step(
|
||||
self,
|
||||
user_id: str,
|
||||
datasource_parameters: dict[str, Any],
|
||||
rag_pipeline_id: Optional[str] = None,
|
||||
) -> Generator[DatasourceInvokeMessage, None, None]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
|
||||
|
||||
yield from manager.invoke(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
datasource_parameters=datasource_parameters,
|
||||
rag_pipeline_id=rag_pipeline_id,
|
||||
)
|
||||
|
||||
|
||||
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
|
||||
return DatasourcePlugin(
|
||||
entity=self.entity,
|
||||
runtime=runtime,
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
rag_pipeline_id: Optional[str] = None,
|
||||
) -> list[DatasourceParameter]:
|
||||
"""
|
||||
get the runtime parameters
|
||||
"""
|
||||
if not self.entity.has_runtime_parameters:
|
||||
return self.entity.parameters
|
||||
|
||||
if self.runtime_parameters is not None:
|
||||
return self.runtime_parameters
|
||||
|
||||
manager = PluginDatasourceManager()
|
||||
self.runtime_parameters = manager.get_runtime_parameters(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="",
|
||||
provider=self.entity.identity.provider,
|
||||
datasource=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
rag_pipeline_id=rag_pipeline_id,
|
||||
)
|
||||
|
||||
return self.runtime_parameters
|
||||
145
api/core/datasource/__base/datasource_provider.py
Normal file
145
api/core/datasource/__base/datasource_provider.py
Normal file
@ -0,0 +1,145 @@
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.manager.tool import PluginToolManager
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class DatasourcePluginProviderController(BuiltinToolProviderController):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
tenant_id: str
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||
) -> None:
|
||||
self.entity = entity
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.RAG_PIPELINE
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
manager = PluginToolManager()
|
||||
if not manager.validate_provider_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
provider=self.entity.identity.name,
|
||||
credentials=credentials,
|
||||
):
|
||||
raise ToolProviderCredentialValidationError("Invalid credentials")
|
||||
|
||||
def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore
|
||||
"""
|
||||
return datasource with given name
|
||||
"""
|
||||
datasource_entity = next(
|
||||
(datasource_entity for datasource_entity in self.entity.datasources if datasource_entity.identity.name == datasource_name), None
|
||||
)
|
||||
|
||||
if not datasource_entity:
|
||||
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||
|
||||
return DatasourcePlugin(
|
||||
entity=datasource_entity,
|
||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
|
||||
def get_datasources(self) -> list[DatasourceTool]: # type: ignore
|
||||
"""
|
||||
get all datasources
|
||||
"""
|
||||
return [
|
||||
DatasourcePlugin(
|
||||
entity=datasource_entity,
|
||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
for datasource_entity in self.entity.datasources
|
||||
]
|
||||
|
||||
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the format of the credentials of the provider and set the default value if needed
|
||||
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
credentials_schema = dict[str, ProviderConfig]()
|
||||
if credentials_schema is None:
|
||||
return
|
||||
|
||||
for credential in self.entity.credentials_schema:
|
||||
credentials_schema[credential.name] = credential
|
||||
|
||||
credentials_need_to_validate: dict[str, ProviderConfig] = {}
|
||||
for credential_name in credentials_schema:
|
||||
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
|
||||
|
||||
for credential_name in credentials:
|
||||
if credential_name not in credentials_need_to_validate:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"credential {credential_name} not found in provider {self.entity.identity.name}"
|
||||
)
|
||||
|
||||
# check type
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if not credential_schema.required and credentials[credential_name] is None:
|
||||
continue
|
||||
|
||||
if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||
|
||||
elif credential_schema.type == ProviderConfig.Type.SELECT:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||
|
||||
options = credential_schema.options
|
||||
if not isinstance(options, list):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
|
||||
|
||||
if credentials[credential_name] not in [x.value for x in options]:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"credential {credential_name} should be one of {options}"
|
||||
)
|
||||
|
||||
credentials_need_to_validate.pop(credential_name)
|
||||
|
||||
for credential_name in credentials_need_to_validate:
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if credential_schema.required:
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
|
||||
|
||||
# the credential is not set currently, set the default value if needed
|
||||
if credential_schema.default is not None:
|
||||
default_value = credential_schema.default
|
||||
# parse default value into the correct type
|
||||
if credential_schema.type in {
|
||||
ProviderConfig.Type.SECRET_INPUT,
|
||||
ProviderConfig.Type.TEXT_INPUT,
|
||||
ProviderConfig.Type.SELECT,
|
||||
}:
|
||||
default_value = str(default_value)
|
||||
|
||||
credentials[credential_name] = default_value
|
||||
@ -1,109 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolProviderEntity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class ToolProviderController(ABC):
|
||||
entity: ToolProviderEntity
|
||||
|
||||
def __init__(self, entity: ToolProviderEntity) -> None:
|
||||
self.entity = entity
|
||||
|
||||
def get_credentials_schema(self) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:return: the credentials schema
|
||||
"""
|
||||
return deepcopy(self.entity.credentials_schema)
|
||||
|
||||
@abstractmethod
|
||||
def get_tool(self, tool_name: str) -> Tool:
|
||||
"""
|
||||
returns a tool that the provider can provide
|
||||
|
||||
:return: tool
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the format of the credentials of the provider and set the default value if needed
|
||||
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
credentials_schema = dict[str, ProviderConfig]()
|
||||
if credentials_schema is None:
|
||||
return
|
||||
|
||||
for credential in self.entity.credentials_schema:
|
||||
credentials_schema[credential.name] = credential
|
||||
|
||||
credentials_need_to_validate: dict[str, ProviderConfig] = {}
|
||||
for credential_name in credentials_schema:
|
||||
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
|
||||
|
||||
for credential_name in credentials:
|
||||
if credential_name not in credentials_need_to_validate:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"credential {credential_name} not found in provider {self.entity.identity.name}"
|
||||
)
|
||||
|
||||
# check type
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if not credential_schema.required and credentials[credential_name] is None:
|
||||
continue
|
||||
|
||||
if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||
|
||||
elif credential_schema.type == ProviderConfig.Type.SELECT:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||
|
||||
options = credential_schema.options
|
||||
if not isinstance(options, list):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
|
||||
|
||||
if credentials[credential_name] not in [x.value for x in options]:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"credential {credential_name} should be one of {options}"
|
||||
)
|
||||
|
||||
credentials_need_to_validate.pop(credential_name)
|
||||
|
||||
for credential_name in credentials_need_to_validate:
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if credential_schema.required:
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
|
||||
|
||||
# the credential is not set currently, set the default value if needed
|
||||
if credential_schema.default is not None:
|
||||
default_value = credential_schema.default
|
||||
# parse default value into the correct type
|
||||
if credential_schema.type in {
|
||||
ProviderConfig.Type.SECRET_INPUT,
|
||||
ProviderConfig.Type.TEXT_INPUT,
|
||||
ProviderConfig.Type.SELECT,
|
||||
}:
|
||||
default_value = str(default_value)
|
||||
|
||||
credentials[credential_name] = default_value
|
||||
Reference in New Issue
Block a user