mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 08:58:09 +08:00
r2
This commit is contained in:
@ -1,18 +1,13 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceProviderType,
|
||||
)
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
||||
|
||||
|
||||
class DatasourcePlugin:
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
class DatasourcePlugin(ABC):
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
|
||||
@ -20,57 +15,19 @@ class DatasourcePlugin:
|
||||
self,
|
||||
entity: DatasourceEntity,
|
||||
runtime: DatasourceRuntime,
|
||||
tenant_id: str,
|
||||
icon: str,
|
||||
plugin_unique_identifier: str,
|
||||
) -> None:
|
||||
self.entity = entity
|
||||
self.runtime = runtime
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def _invoke_first_step(
|
||||
self,
|
||||
user_id: str,
|
||||
datasource_parameters: dict[str, Any],
|
||||
) -> Mapping[str, Any]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
|
||||
|
||||
return manager.invoke_first_step(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
datasource_parameters=datasource_parameters,
|
||||
)
|
||||
|
||||
def _invoke_second_step(
|
||||
self,
|
||||
user_id: str,
|
||||
datasource_parameters: dict[str, Any],
|
||||
) -> Mapping[str, Any]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
|
||||
|
||||
return manager.invoke_second_step(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
datasource_parameters=datasource_parameters,
|
||||
)
|
||||
@abstractmethod
|
||||
def datasource_provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the datasource provider
|
||||
"""
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
|
||||
return DatasourcePlugin(
|
||||
entity=self.entity,
|
||||
return self.__class__(
|
||||
entity=self.entity.model_copy(),
|
||||
runtime=runtime,
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
|
||||
@ -1,26 +1,19 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class DatasourcePluginProviderController:
|
||||
class DatasourcePluginProviderController(ABC):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
tenant_id: str
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||
) -> None:
|
||||
def __init__(self, entity: DatasourceProviderEntityWithPlugin) -> None:
|
||||
self.entity = entity
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def need_credentials(self) -> bool:
|
||||
@ -44,29 +37,19 @@ class DatasourcePluginProviderController:
|
||||
):
|
||||
raise ToolProviderCredentialValidationError("Invalid credentials")
|
||||
|
||||
def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore
|
||||
@property
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
@abstractmethod
|
||||
def get_datasource(self, datasource_name: str) -> DatasourcePlugin:
|
||||
"""
|
||||
return datasource with given name
|
||||
"""
|
||||
datasource_entity = next(
|
||||
(
|
||||
datasource_entity
|
||||
for datasource_entity in self.entity.datasources
|
||||
if datasource_entity.identity.name == datasource_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not datasource_entity:
|
||||
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||
|
||||
return DatasourcePlugin(
|
||||
entity=datasource_entity,
|
||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
pass
|
||||
|
||||
def get_datasources(self) -> list[DatasourcePlugin]: # type: ignore
|
||||
"""
|
||||
|
||||
@ -28,13 +28,13 @@ class DatasourceProviderApiEntity(BaseModel):
|
||||
description: I18nObject
|
||||
icon: str | dict
|
||||
label: I18nObject # label
|
||||
type: ToolProviderType
|
||||
type: str
|
||||
masked_credentials: Optional[dict] = None
|
||||
original_credentials: Optional[dict] = None
|
||||
is_team_authorization: bool = False
|
||||
allow_delete: bool = True
|
||||
plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool")
|
||||
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool")
|
||||
plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource")
|
||||
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the datasource")
|
||||
datasources: list[DatasourceApiEntity] = Field(default_factory=list)
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ class DatasourceProviderType(enum.StrEnum):
|
||||
|
||||
ONLINE_DOCUMENT = "online_document"
|
||||
LOCAL_FILE = "local_file"
|
||||
WEBSITE = "website"
|
||||
WEBSITE_CRAWL = "website_crawl"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "DatasourceProviderType":
|
||||
@ -111,10 +111,10 @@ class DatasourceParameter(PluginParameter):
|
||||
|
||||
|
||||
class DatasourceIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the tool")
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
provider: str = Field(..., description="The provider of the tool")
|
||||
author: str = Field(..., description="The author of the datasource")
|
||||
name: str = Field(..., description="The name of the datasource")
|
||||
label: I18nObject = Field(..., description="The label of the datasource")
|
||||
provider: str = Field(..., description="The provider of the datasource")
|
||||
icon: Optional[str] = None
|
||||
|
||||
|
||||
@ -145,7 +145,7 @@ class DatasourceProviderEntity(ToolProviderEntity):
|
||||
|
||||
|
||||
class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
|
||||
datasources: list[DatasourceEntity] = Field(default_factory=list)
|
||||
datasources: list[DatasourceEntity] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DatasourceInvokeMeta(BaseModel):
|
||||
@ -195,3 +195,105 @@ class DatasourceInvokeFrom(Enum):
|
||||
"""
|
||||
|
||||
RAG_PIPELINE = "rag_pipeline"
|
||||
|
||||
|
||||
class GetOnlineDocumentPagesRequest(BaseModel):
|
||||
"""
|
||||
Get online document pages request
|
||||
"""
|
||||
|
||||
tenant_id: str = Field(..., description="The tenant id")
|
||||
|
||||
|
||||
class OnlineDocumentPageIcon(BaseModel):
|
||||
"""
|
||||
Online document page icon
|
||||
"""
|
||||
|
||||
type: str = Field(..., description="The type of the icon")
|
||||
url: str = Field(..., description="The url of the icon")
|
||||
|
||||
|
||||
class OnlineDocumentPage(BaseModel):
|
||||
"""
|
||||
Online document page
|
||||
"""
|
||||
|
||||
page_id: str = Field(..., description="The page id")
|
||||
page_title: str = Field(..., description="The page title")
|
||||
page_icon: Optional[OnlineDocumentPageIcon] = Field(None, description="The page icon")
|
||||
type: str = Field(..., description="The type of the page")
|
||||
last_edited_time: str = Field(..., description="The last edited time")
|
||||
|
||||
|
||||
class OnlineDocumentInfo(BaseModel):
|
||||
"""
|
||||
Online document info
|
||||
"""
|
||||
|
||||
workspace_id: str = Field(..., description="The workspace id")
|
||||
workspace_name: str = Field(..., description="The workspace name")
|
||||
workspace_icon: str = Field(..., description="The workspace icon")
|
||||
total: int = Field(..., description="The total number of documents")
|
||||
pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")
|
||||
|
||||
|
||||
class GetOnlineDocumentPagesResponse(BaseModel):
|
||||
"""
|
||||
Get online document pages response
|
||||
"""
|
||||
|
||||
result: list[OnlineDocumentInfo]
|
||||
|
||||
|
||||
class GetOnlineDocumentPageContentRequest(BaseModel):
|
||||
"""
|
||||
Get online document page content request
|
||||
"""
|
||||
|
||||
online_document_info_list: list[OnlineDocumentInfo]
|
||||
|
||||
|
||||
class OnlineDocumentPageContent(BaseModel):
|
||||
"""
|
||||
Online document page content
|
||||
"""
|
||||
|
||||
page_id: str = Field(..., description="The page id")
|
||||
content: str = Field(..., description="The content of the page")
|
||||
|
||||
|
||||
class GetOnlineDocumentPageContentResponse(BaseModel):
|
||||
"""
|
||||
Get online document page content response
|
||||
"""
|
||||
|
||||
result: list[OnlineDocumentPageContent]
|
||||
|
||||
|
||||
class GetWebsiteCrawlRequest(BaseModel):
|
||||
"""
|
||||
Get website crawl request
|
||||
"""
|
||||
|
||||
url: str = Field(..., description="The url of the website")
|
||||
crawl_parameters: dict = Field(..., description="The crawl parameters")
|
||||
|
||||
|
||||
class WebSiteInfo(BaseModel):
|
||||
"""
|
||||
Website info
|
||||
"""
|
||||
|
||||
source_url: str = Field(..., description="The url of the website")
|
||||
markdown: str = Field(..., description="The markdown of the website")
|
||||
title: str = Field(..., description="The title of the website")
|
||||
description: str = Field(..., description="The description of the website")
|
||||
|
||||
|
||||
class GetWebsiteCrawlResponse(BaseModel):
|
||||
"""
|
||||
Get website crawl response
|
||||
"""
|
||||
|
||||
result: list[WebSiteInfo]
|
||||
|
||||
37
api/core/datasource/local_file/local_file_plugin.py
Normal file
37
api/core/datasource/local_file/local_file_plugin.py
Normal file
@ -0,0 +1,37 @@
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceProviderType,
|
||||
)
|
||||
|
||||
|
||||
class LocalFileDatasourcePlugin(DatasourcePlugin):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity: DatasourceEntity,
|
||||
runtime: DatasourceRuntime,
|
||||
tenant_id: str,
|
||||
icon: str,
|
||||
plugin_unique_identifier: str,
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def datasource_provider_type(self) -> DatasourceProviderType:
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
|
||||
return DatasourcePlugin(
|
||||
entity=self.entity,
|
||||
runtime=runtime,
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
58
api/core/datasource/local_file/local_file_provider.py
Normal file
58
api/core/datasource/local_file/local_file_provider.py
Normal file
@ -0,0 +1,58 @@
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin
|
||||
|
||||
|
||||
class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
tenant_id: str
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||
) -> None:
|
||||
super().__init__(entity)
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_datasource(self, datasource_name: str) -> LocalFileDatasourcePlugin: # type: ignore
|
||||
"""
|
||||
return datasource with given name
|
||||
"""
|
||||
datasource_entity = next(
|
||||
(
|
||||
datasource_entity
|
||||
for datasource_entity in self.entity.datasources
|
||||
if datasource_entity.identity.name == datasource_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not datasource_entity:
|
||||
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||
|
||||
return LocalFileDatasourcePlugin(
|
||||
entity=datasource_entity,
|
||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
@ -0,0 +1,80 @@
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceProviderType,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
GetOnlineDocumentPageContentResponse,
|
||||
GetOnlineDocumentPagesRequest,
|
||||
GetOnlineDocumentPagesResponse,
|
||||
)
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
|
||||
|
||||
class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity: DatasourceEntity,
|
||||
runtime: DatasourceRuntime,
|
||||
tenant_id: str,
|
||||
icon: str,
|
||||
plugin_unique_identifier: str,
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def _get_online_document_pages(
|
||||
self,
|
||||
user_id: str,
|
||||
datasource_parameters: GetOnlineDocumentPagesRequest,
|
||||
provider_type: str,
|
||||
) -> GetOnlineDocumentPagesResponse:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
return manager.get_online_document_pages(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
datasource_parameters=datasource_parameters,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def _get_online_document_page_content(
|
||||
self,
|
||||
user_id: str,
|
||||
datasource_parameters: GetOnlineDocumentPageContentRequest,
|
||||
provider_type: str,
|
||||
) -> GetOnlineDocumentPageContentResponse:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
return manager.get_online_document_page_content(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
datasource_parameters=datasource_parameters,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def datasource_provider_type(self) -> DatasourceProviderType:
|
||||
return DatasourceProviderType.ONLINE_DOCUMENT
|
||||
|
||||
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
|
||||
return DatasourcePlugin(
|
||||
entity=self.entity,
|
||||
runtime=runtime,
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
@ -0,0 +1,50 @@
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
|
||||
|
||||
class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
tenant_id: str
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||
) -> None:
|
||||
super().__init__(entity)
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.ONLINE_DOCUMENT
|
||||
|
||||
def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore
|
||||
"""
|
||||
return datasource with given name
|
||||
"""
|
||||
datasource_entity = next(
|
||||
(
|
||||
datasource_entity
|
||||
for datasource_entity in self.entity.datasources
|
||||
if datasource_entity.identity.name == datasource_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not datasource_entity:
|
||||
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||
|
||||
return DatasourcePlugin(
|
||||
entity=datasource_entity,
|
||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
63
api/core/datasource/website_crawl/website_crawl_plugin.py
Normal file
63
api/core/datasource/website_crawl/website_crawl_plugin.py
Normal file
@ -0,0 +1,63 @@
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceProviderType,
|
||||
GetWebsiteCrawlRequest,
|
||||
GetWebsiteCrawlResponse,
|
||||
)
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
||||
|
||||
|
||||
class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity: DatasourceEntity,
|
||||
runtime: DatasourceRuntime,
|
||||
tenant_id: str,
|
||||
icon: str,
|
||||
plugin_unique_identifier: str,
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def _get_website_crawl(
|
||||
self,
|
||||
user_id: str,
|
||||
datasource_parameters: GetWebsiteCrawlRequest,
|
||||
provider_type: str,
|
||||
) -> GetWebsiteCrawlResponse:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
|
||||
|
||||
return manager.invoke_first_step(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
datasource_parameters=datasource_parameters,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def datasource_provider_type(self) -> DatasourceProviderType:
|
||||
return DatasourceProviderType.WEBSITE_CRAWL
|
||||
|
||||
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
|
||||
return DatasourcePlugin(
|
||||
entity=self.entity,
|
||||
runtime=runtime,
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
50
api/core/datasource/website_crawl/website_crawl_provider.py
Normal file
50
api/core/datasource/website_crawl/website_crawl_provider.py
Normal file
@ -0,0 +1,50 @@
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
|
||||
|
||||
class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
tenant_id: str
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||
) -> None:
|
||||
super().__init__(entity)
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.WEBSITE_CRAWL
|
||||
|
||||
def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore
|
||||
"""
|
||||
return datasource with given name
|
||||
"""
|
||||
datasource_entity = next(
|
||||
(
|
||||
datasource_entity
|
||||
for datasource_entity in self.entity.datasources
|
||||
if datasource_entity.identity.name == datasource_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not datasource_entity:
|
||||
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||
|
||||
return DatasourcePlugin(
|
||||
entity=datasource_entity,
|
||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
Reference in New Issue
Block a user