Merge branch 'main' into feat/r2

This commit is contained in:
Yeuoly
2025-04-27 14:31:19 +08:00
874 changed files with 31114 additions and 19811 deletions

View File

@ -18,10 +18,15 @@ class DatasourcePlugin:
plugin_unique_identifier: str
runtime_parameters: Optional[list[DatasourceParameter]]
entity: DatasourceEntity
runtime: DatasourceRuntime
runtime: DatasourceRuntime
def __init__(
self, entity: DatasourceEntity, runtime: DatasourceRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
tenant_id: str,
icon: str,
plugin_unique_identifier: str,
) -> None:
self.entity = entity
self.runtime = runtime
@ -73,7 +78,6 @@ class DatasourcePlugin:
rag_pipeline_id=rag_pipeline_id,
)
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
return DatasourcePlugin(
entity=self.entity,

View File

@ -50,7 +50,12 @@ class DatasourcePluginProviderController(BuiltinToolProviderController):
return datasource with given name
"""
datasource_entity = next(
(datasource_entity for datasource_entity in self.entity.datasources if datasource_entity.identity.name == datasource_name), None
(
datasource_entity
for datasource_entity in self.entity.datasources
if datasource_entity.identity.name == datasource_name
),
None,
)
if not datasource_entity:
@ -78,68 +83,68 @@ class DatasourcePluginProviderController(BuiltinToolProviderController):
)
for datasource_entity in self.entity.datasources
]
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
"""
validate the format of the credentials of the provider and set the default value if needed
"""
validate the format of the credentials of the provider and set the default value if needed
:param credentials: the credentials of the tool
"""
credentials_schema = dict[str, ProviderConfig]()
if credentials_schema is None:
return
:param credentials: the credentials of the tool
"""
credentials_schema = dict[str, ProviderConfig]()
if credentials_schema is None:
return
for credential in self.entity.credentials_schema:
credentials_schema[credential.name] = credential
for credential in self.entity.credentials_schema:
credentials_schema[credential.name] = credential
credentials_need_to_validate: dict[str, ProviderConfig] = {}
for credential_name in credentials_schema:
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
credentials_need_to_validate: dict[str, ProviderConfig] = {}
for credential_name in credentials_schema:
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
for credential_name in credentials:
if credential_name not in credentials_need_to_validate:
for credential_name in credentials:
if credential_name not in credentials_need_to_validate:
raise ToolProviderCredentialValidationError(
f"credential {credential_name} not found in provider {self.entity.identity.name}"
)
# check type
credential_schema = credentials_need_to_validate[credential_name]
if not credential_schema.required and credentials[credential_name] is None:
continue
if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
elif credential_schema.type == ProviderConfig.Type.SELECT:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
options = credential_schema.options
if not isinstance(options, list):
raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
if credentials[credential_name] not in [x.value for x in options]:
raise ToolProviderCredentialValidationError(
f"credential {credential_name} not found in provider {self.entity.identity.name}"
f"credential {credential_name} should be one of {options}"
)
# check type
credential_schema = credentials_need_to_validate[credential_name]
if not credential_schema.required and credentials[credential_name] is None:
continue
credentials_need_to_validate.pop(credential_name)
if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
for credential_name in credentials_need_to_validate:
credential_schema = credentials_need_to_validate[credential_name]
if credential_schema.required:
raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
elif credential_schema.type == ProviderConfig.Type.SELECT:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
# the credential is not set currently, set the default value if needed
if credential_schema.default is not None:
default_value = credential_schema.default
# parse default value into the correct type
if credential_schema.type in {
ProviderConfig.Type.SECRET_INPUT,
ProviderConfig.Type.TEXT_INPUT,
ProviderConfig.Type.SELECT,
}:
default_value = str(default_value)
options = credential_schema.options
if not isinstance(options, list):
raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
if credentials[credential_name] not in [x.value for x in options]:
raise ToolProviderCredentialValidationError(
f"credential {credential_name} should be one of {options}"
)
credentials_need_to_validate.pop(credential_name)
for credential_name in credentials_need_to_validate:
credential_schema = credentials_need_to_validate[credential_name]
if credential_schema.required:
raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
# the credential is not set currently, set the default value if needed
if credential_schema.default is not None:
default_value = credential_schema.default
# parse default value into the correct type
if credential_schema.type in {
ProviderConfig.Type.SECRET_INPUT,
ProviderConfig.Type.TEXT_INPUT,
ProviderConfig.Type.SELECT,
}:
default_value = str(default_value)
credentials[credential_name] = default_value
credentials[credential_name] = default_value

View File

@ -39,8 +39,9 @@ class DatasourceEngine:
"""
try:
# hit the callback handler
workflow_tool_callback.on_datasource_start(datasource_name=datasource.entity.identity.name,
datasource_inputs=datasource_parameters)
workflow_tool_callback.on_datasource_start(
datasource_name=datasource.entity.identity.name, datasource_inputs=datasource_parameters
)
if datasource.runtime and datasource.runtime.runtime_parameters:
datasource_parameters = {**datasource.runtime.runtime_parameters, **datasource_parameters}
@ -86,7 +87,6 @@ class DatasourceEngine:
workflow_tool_callback.on_tool_error(e)
raise e
@staticmethod
def _convert_datasource_response_to_str(datasource_response: list[DatasourceInvokeMessage]) -> str:
"""
@ -101,7 +101,10 @@ class DatasourceEngine:
f"result link: {cast(DatasourceInvokeMessage.TextMessage, response.message).text}."
+ " please tell user to check it."
)
elif response.type in {DatasourceInvokeMessage.MessageType.IMAGE_LINK, DatasourceInvokeMessage.MessageType.IMAGE}:
elif response.type in {
DatasourceInvokeMessage.MessageType.IMAGE_LINK,
DatasourceInvokeMessage.MessageType.IMAGE,
}:
result += (
"image has been created and sent to user already, "
+ "you do not need to create it, just tell the user to check it now."
@ -123,7 +126,10 @@ class DatasourceEngine:
Extract datasource response binary
"""
for response in datasource_response:
if response.type in {DatasourceInvokeMessage.MessageType.IMAGE_LINK, DatasourceInvokeMessage.MessageType.IMAGE}:
if response.type in {
DatasourceInvokeMessage.MessageType.IMAGE_LINK,
DatasourceInvokeMessage.MessageType.IMAGE,
}:
mimetype = None
if not response.meta:
raise ValueError("missing meta data")

View File

@ -1,4 +1,3 @@
import logging
from threading import Lock
from typing import Union
@ -75,8 +74,7 @@ class DatasourceManager:
return cls.get_datasource_plugin_provider(provider_id, tenant_id).get_datasource(datasource_name)
else:
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
@classmethod
def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]:
"""