mirror of
https://github.com/langgenius/dify.git
synced 2026-04-30 23:48:04 +08:00
Merge branch 'main' into feat/r2
This commit is contained in:
@ -18,10 +18,15 @@ class DatasourcePlugin:
|
||||
plugin_unique_identifier: str
|
||||
runtime_parameters: Optional[list[DatasourceParameter]]
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
runtime: DatasourceRuntime
|
||||
|
||||
def __init__(
|
||||
self, entity: DatasourceEntity, runtime: DatasourceRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str
|
||||
self,
|
||||
entity: DatasourceEntity,
|
||||
runtime: DatasourceRuntime,
|
||||
tenant_id: str,
|
||||
icon: str,
|
||||
plugin_unique_identifier: str,
|
||||
) -> None:
|
||||
self.entity = entity
|
||||
self.runtime = runtime
|
||||
@ -73,7 +78,6 @@ class DatasourcePlugin:
|
||||
rag_pipeline_id=rag_pipeline_id,
|
||||
)
|
||||
|
||||
|
||||
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
|
||||
return DatasourcePlugin(
|
||||
entity=self.entity,
|
||||
|
||||
@ -50,7 +50,12 @@ class DatasourcePluginProviderController(BuiltinToolProviderController):
|
||||
return datasource with given name
|
||||
"""
|
||||
datasource_entity = next(
|
||||
(datasource_entity for datasource_entity in self.entity.datasources if datasource_entity.identity.name == datasource_name), None
|
||||
(
|
||||
datasource_entity
|
||||
for datasource_entity in self.entity.datasources
|
||||
if datasource_entity.identity.name == datasource_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not datasource_entity:
|
||||
@ -78,68 +83,68 @@ class DatasourcePluginProviderController(BuiltinToolProviderController):
|
||||
)
|
||||
for datasource_entity in self.entity.datasources
|
||||
]
|
||||
|
||||
|
||||
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the format of the credentials of the provider and set the default value if needed
|
||||
"""
|
||||
validate the format of the credentials of the provider and set the default value if needed
|
||||
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
credentials_schema = dict[str, ProviderConfig]()
|
||||
if credentials_schema is None:
|
||||
return
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
credentials_schema = dict[str, ProviderConfig]()
|
||||
if credentials_schema is None:
|
||||
return
|
||||
|
||||
for credential in self.entity.credentials_schema:
|
||||
credentials_schema[credential.name] = credential
|
||||
for credential in self.entity.credentials_schema:
|
||||
credentials_schema[credential.name] = credential
|
||||
|
||||
credentials_need_to_validate: dict[str, ProviderConfig] = {}
|
||||
for credential_name in credentials_schema:
|
||||
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
|
||||
credentials_need_to_validate: dict[str, ProviderConfig] = {}
|
||||
for credential_name in credentials_schema:
|
||||
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
|
||||
|
||||
for credential_name in credentials:
|
||||
if credential_name not in credentials_need_to_validate:
|
||||
for credential_name in credentials:
|
||||
if credential_name not in credentials_need_to_validate:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"credential {credential_name} not found in provider {self.entity.identity.name}"
|
||||
)
|
||||
|
||||
# check type
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if not credential_schema.required and credentials[credential_name] is None:
|
||||
continue
|
||||
|
||||
if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||
|
||||
elif credential_schema.type == ProviderConfig.Type.SELECT:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||
|
||||
options = credential_schema.options
|
||||
if not isinstance(options, list):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
|
||||
|
||||
if credentials[credential_name] not in [x.value for x in options]:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"credential {credential_name} not found in provider {self.entity.identity.name}"
|
||||
f"credential {credential_name} should be one of {options}"
|
||||
)
|
||||
|
||||
# check type
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if not credential_schema.required and credentials[credential_name] is None:
|
||||
continue
|
||||
credentials_need_to_validate.pop(credential_name)
|
||||
|
||||
if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||
for credential_name in credentials_need_to_validate:
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if credential_schema.required:
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
|
||||
|
||||
elif credential_schema.type == ProviderConfig.Type.SELECT:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||
# the credential is not set currently, set the default value if needed
|
||||
if credential_schema.default is not None:
|
||||
default_value = credential_schema.default
|
||||
# parse default value into the correct type
|
||||
if credential_schema.type in {
|
||||
ProviderConfig.Type.SECRET_INPUT,
|
||||
ProviderConfig.Type.TEXT_INPUT,
|
||||
ProviderConfig.Type.SELECT,
|
||||
}:
|
||||
default_value = str(default_value)
|
||||
|
||||
options = credential_schema.options
|
||||
if not isinstance(options, list):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
|
||||
|
||||
if credentials[credential_name] not in [x.value for x in options]:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"credential {credential_name} should be one of {options}"
|
||||
)
|
||||
|
||||
credentials_need_to_validate.pop(credential_name)
|
||||
|
||||
for credential_name in credentials_need_to_validate:
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if credential_schema.required:
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
|
||||
|
||||
# the credential is not set currently, set the default value if needed
|
||||
if credential_schema.default is not None:
|
||||
default_value = credential_schema.default
|
||||
# parse default value into the correct type
|
||||
if credential_schema.type in {
|
||||
ProviderConfig.Type.SECRET_INPUT,
|
||||
ProviderConfig.Type.TEXT_INPUT,
|
||||
ProviderConfig.Type.SELECT,
|
||||
}:
|
||||
default_value = str(default_value)
|
||||
|
||||
credentials[credential_name] = default_value
|
||||
credentials[credential_name] = default_value
|
||||
|
||||
@ -39,8 +39,9 @@ class DatasourceEngine:
|
||||
"""
|
||||
try:
|
||||
# hit the callback handler
|
||||
workflow_tool_callback.on_datasource_start(datasource_name=datasource.entity.identity.name,
|
||||
datasource_inputs=datasource_parameters)
|
||||
workflow_tool_callback.on_datasource_start(
|
||||
datasource_name=datasource.entity.identity.name, datasource_inputs=datasource_parameters
|
||||
)
|
||||
|
||||
if datasource.runtime and datasource.runtime.runtime_parameters:
|
||||
datasource_parameters = {**datasource.runtime.runtime_parameters, **datasource_parameters}
|
||||
@ -86,7 +87,6 @@ class DatasourceEngine:
|
||||
workflow_tool_callback.on_tool_error(e)
|
||||
raise e
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _convert_datasource_response_to_str(datasource_response: list[DatasourceInvokeMessage]) -> str:
|
||||
"""
|
||||
@ -101,7 +101,10 @@ class DatasourceEngine:
|
||||
f"result link: {cast(DatasourceInvokeMessage.TextMessage, response.message).text}."
|
||||
+ " please tell user to check it."
|
||||
)
|
||||
elif response.type in {DatasourceInvokeMessage.MessageType.IMAGE_LINK, DatasourceInvokeMessage.MessageType.IMAGE}:
|
||||
elif response.type in {
|
||||
DatasourceInvokeMessage.MessageType.IMAGE_LINK,
|
||||
DatasourceInvokeMessage.MessageType.IMAGE,
|
||||
}:
|
||||
result += (
|
||||
"image has been created and sent to user already, "
|
||||
+ "you do not need to create it, just tell the user to check it now."
|
||||
@ -123,7 +126,10 @@ class DatasourceEngine:
|
||||
Extract datasource response binary
|
||||
"""
|
||||
for response in datasource_response:
|
||||
if response.type in {DatasourceInvokeMessage.MessageType.IMAGE_LINK, DatasourceInvokeMessage.MessageType.IMAGE}:
|
||||
if response.type in {
|
||||
DatasourceInvokeMessage.MessageType.IMAGE_LINK,
|
||||
DatasourceInvokeMessage.MessageType.IMAGE,
|
||||
}:
|
||||
mimetype = None
|
||||
if not response.meta:
|
||||
raise ValueError("missing meta data")
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
import logging
|
||||
from threading import Lock
|
||||
from typing import Union
|
||||
@ -75,8 +74,7 @@ class DatasourceManager:
|
||||
return cls.get_datasource_plugin_provider(provider_id, tenant_id).get_datasource(datasource_name)
|
||||
else:
|
||||
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
|
||||
|
||||
|
||||
|
||||
@classmethod
|
||||
def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user