This commit is contained in:
jyong
2025-05-23 15:55:41 +08:00
parent a49942b949
commit 64d997fdb0
16 changed files with 176 additions and 198 deletions

View File

@ -6,7 +6,11 @@ import contexts
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.entities.common_entities import I18nObject
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.datasource.errors import DatasourceProviderNotFoundError
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
from core.plugin.impl.datasource import PluginDatasourceManager
logger = logging.getLogger(__name__)
@ -19,7 +23,9 @@ class DatasourceManager:
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
@classmethod
def get_datasource_plugin_provider(cls, provider: str, tenant_id: str) -> DatasourcePluginProviderController:
def get_datasource_plugin_provider(
cls, provider: str, tenant_id: str, datasource_type: DatasourceProviderType
) -> DatasourcePluginProviderController:
"""
get the datasource plugin provider
"""
@ -40,12 +46,30 @@ class DatasourceManager:
if not provider_entity:
raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found")
controller = DatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
match (datasource_type):
case DatasourceProviderType.ONLINE_DOCUMENT:
controller = OnlineDocumentDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case DatasourceProviderType.WEBSITE_CRAWL:
controller = WebsiteCrawlDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case DatasourceProviderType.LOCAL_FILE:
controller = LocalFileDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case _:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
datasource_plugin_providers[provider] = controller
@ -57,6 +81,7 @@ class DatasourceManager:
provider_id: str,
datasource_name: str,
tenant_id: str,
datasource_type: DatasourceProviderType,
) -> DatasourcePlugin:
"""
get the datasource runtime
@ -68,21 +93,10 @@ class DatasourceManager:
:return: the datasource plugin
"""
return cls.get_datasource_plugin_provider(provider_id, tenant_id).get_datasource(datasource_name)
return cls.get_datasource_plugin_provider(
provider_id,
tenant_id,
datasource_type,
).get_datasource(datasource_name)
@classmethod
def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]:
"""
list all the datasource providers
"""
manager = PluginDatasourceManager()
provider_entities = manager.fetch_datasource_providers(tenant_id)
return [
DatasourcePluginProviderController(
entity=provider.declaration,
plugin_id=provider.plugin_id,
plugin_unique_identifier=provider.plugin_unique_identifier,
tenant_id=tenant_id,
)
for provider in provider_entities
]