feat(online_driver): add online driver plugin, support browsing and downloading

This commit is contained in:
Harry
2025-06-27 16:41:39 +08:00
parent efccbe4039
commit eee72101f4
4 changed files with 251 additions and 0 deletions

View File

@ -26,6 +26,7 @@ class DatasourceProviderType(enum.StrEnum):
ONLINE_DOCUMENT = "online_document"
LOCAL_FILE = "local_file"
WEBSITE_CRAWL = "website_crawl"
ONLINE_DRIVER = "online_driver"
@classmethod
def value_of(cls, value: str) -> "DatasourceProviderType":
@ -303,3 +304,57 @@ class WebsiteCrawlMessage(BaseModel):
class DatasourceMessage(ToolInvokeMessage):
pass
#########################
# Online driver file
#########################
class OnlineDriverFile(BaseModel):
"""
Online driver file
"""
key: str = Field(..., description="The key of the file")
size: int = Field(..., description="The size of the file")
class OnlineDriverFileBucket(BaseModel):
"""
Online driver file bucket
"""
bucket: Optional[str] = Field(None, description="The bucket of the file")
files: list[OnlineDriverFile] = Field(..., description="The files of the bucket")
is_truncated: bool = Field(False, description="Whether the bucket has more files")
class OnlineDriverBrowseFilesRequest(BaseModel):
"""
Get online driver file list request
"""
prefix: Optional[str] = Field(None, description="File path prefix for filtering eg: 'docs/dify/'")
bucket: Optional[str] = Field(None, description="Storage bucket name")
max_keys: int = Field(20, description="Maximum number of files to return")
start_after: Optional[str] = Field(
None, description="Pagination token for continuing from a specific file eg: 'docs/dify/1.txt'"
)
class OnlineDriverBrowseFilesResponse(BaseModel):
"""
Get online driver file list response
"""
result: list[OnlineDriverFileBucket] = Field(..., description="The bucket of the files")
class OnlineDriverDownloadFileRequest(BaseModel):
"""
Get online driver file
"""
key: str = Field(..., description="The name of the file")
bucket: Optional[str] = Field(None, description="The name of the bucket")

View File

@ -0,0 +1,73 @@
from collections.abc import Generator
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceMessage,
DatasourceProviderType,
OnlineDriverBrowseFilesRequest,
OnlineDriverBrowseFilesResponse,
OnlineDriverDownloadFileRequest,
)
from core.plugin.impl.datasource import PluginDatasourceManager
class OnlineDriverDatasourcePlugin(DatasourcePlugin):
tenant_id: str
icon: str
plugin_unique_identifier: str
entity: DatasourceEntity
runtime: DatasourceRuntime
def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
tenant_id: str,
icon: str,
plugin_unique_identifier: str,
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier
def online_driver_browse_files(
self,
user_id: str,
request: OnlineDriverBrowseFilesRequest,
provider_type: str,
) -> Generator[OnlineDriverBrowseFilesResponse, None, None]:
manager = PluginDatasourceManager()
return manager.online_driver_browse_files(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
request=request,
provider_type=provider_type,
)
def online_driver_download_file(
self,
user_id: str,
request: OnlineDriverDownloadFileRequest,
provider_type: str,
) -> Generator[DatasourceMessage, None, None]:
manager = PluginDatasourceManager()
return manager.online_driver_download_file(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
request=request,
provider_type=provider_type,
)
def datasource_provider_type(self) -> str:
return DatasourceProviderType.ONLINE_DRIVER

View File

@ -0,0 +1,48 @@
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.online_driver.online_driver_plugin import OnlineDriverDatasourcePlugin
class OnlineDriverDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
plugin_id: str
plugin_unique_identifier: str
def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.ONLINE_DRIVER
def get_datasource(self, datasource_name: str) -> OnlineDriverDatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
datasource_entity = next(
(
datasource_entity
for datasource_entity in self.entity.datasources
if datasource_entity.identity.name == datasource_name
),
None,
)
if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")
return OnlineDriverDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)