This commit is contained in:
jyong
2025-05-23 15:55:41 +08:00
parent a49942b949
commit 64d997fdb0
16 changed files with 176 additions and 198 deletions

View File

@ -99,6 +99,7 @@ class PipelineGenerator(BaseAppGenerator):
)
inputs: Mapping[str, Any] = args["inputs"]
start_node_id: str = args["start_node_id"]
datasource_type: str = args["datasource_type"]
datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
@ -118,7 +119,7 @@ class PipelineGenerator(BaseAppGenerator):
position=position,
account=user,
batch=batch,
document_form=pipeline.dataset.doc_form,
document_form=pipeline.dataset.chunk_structure,
)
db.session.add(document)
db.session.commit()
@ -231,7 +232,7 @@ class PipelineGenerator(BaseAppGenerator):
def single_iteration_generate(
self,
app_model: App,
pipeline: Pipeline,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
@ -255,7 +256,7 @@ class PipelineGenerator(BaseAppGenerator):
raise ValueError("inputs is required")
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
app_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(

View File

@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
from typing import Any
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.entities.provider_entities import ProviderConfig
from core.plugin.impl.tool import PluginToolManager
@ -11,9 +10,11 @@ from core.tools.errors import ToolProviderCredentialValidationError
class DatasourcePluginProviderController(ABC):
entity: DatasourceProviderEntityWithPlugin
tenant_id: str
def __init__(self, entity: DatasourceProviderEntityWithPlugin) -> None:
def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None:
self.entity = entity
self.tenant_id = tenant_id
@property
def need_credentials(self) -> bool:
@ -51,21 +52,6 @@ class DatasourcePluginProviderController(ABC):
"""
pass
def get_datasources(self) -> list[DatasourcePlugin]: # type: ignore
"""
get all datasources
"""
return [
DatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)
for datasource_entity in self.entity.datasources
]
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
"""
validate the format of the credentials of the provider and set the default value if needed

View File

@ -6,7 +6,11 @@ import contexts
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.entities.common_entities import I18nObject
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.datasource.errors import DatasourceProviderNotFoundError
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
from core.plugin.impl.datasource import PluginDatasourceManager
logger = logging.getLogger(__name__)
@ -19,7 +23,9 @@ class DatasourceManager:
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
@classmethod
def get_datasource_plugin_provider(cls, provider: str, tenant_id: str) -> DatasourcePluginProviderController:
def get_datasource_plugin_provider(
cls, provider: str, tenant_id: str, datasource_type: DatasourceProviderType
) -> DatasourcePluginProviderController:
"""
get the datasource plugin provider
"""
@ -40,12 +46,30 @@ class DatasourceManager:
if not provider_entity:
raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found")
controller = DatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
match (datasource_type):
case DatasourceProviderType.ONLINE_DOCUMENT:
controller = OnlineDocumentDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case DatasourceProviderType.WEBSITE_CRAWL:
controller = WebsiteCrawlDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case DatasourceProviderType.LOCAL_FILE:
controller = LocalFileDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case _:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
datasource_plugin_providers[provider] = controller
@ -57,6 +81,7 @@ class DatasourceManager:
provider_id: str,
datasource_name: str,
tenant_id: str,
datasource_type: DatasourceProviderType,
) -> DatasourcePlugin:
"""
get the datasource runtime
@ -68,21 +93,10 @@ class DatasourceManager:
:return: the datasource plugin
"""
return cls.get_datasource_plugin_provider(provider_id, tenant_id).get_datasource(datasource_name)
return cls.get_datasource_plugin_provider(
provider_id,
tenant_id,
datasource_type,
).get_datasource(datasource_name)
@classmethod
def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]:
"""
list all the datasource providers
"""
manager = PluginDatasourceManager()
provider_entities = manager.fetch_datasource_providers(tenant_id)
return [
DatasourcePluginProviderController(
entity=provider.declaration,
plugin_id=provider.plugin_id,
plugin_unique_identifier=provider.plugin_unique_identifier,
tenant_id=tenant_id,
)
for provider in provider_entities
]

View File

@ -251,7 +251,7 @@ class GetOnlineDocumentPageContentRequest(BaseModel):
Get online document page content request
"""
online_document_info_list: list[OnlineDocumentInfo]
online_document_info: OnlineDocumentInfo
class OnlineDocumentPageContent(BaseModel):
@ -259,6 +259,7 @@ class OnlineDocumentPageContent(BaseModel):
Online document page content
"""
workspace_id: str = Field(..., description="The workspace id")
page_id: str = Field(..., description="The page id")
content: str = Field(..., description="The content of the page")
@ -268,7 +269,7 @@ class GetOnlineDocumentPageContentResponse(BaseModel):
Get online document page content response
"""
result: list[OnlineDocumentPageContent]
result: OnlineDocumentPageContent
class GetWebsiteCrawlRequest(BaseModel):
@ -286,7 +287,7 @@ class WebSiteInfo(BaseModel):
"""
source_url: str = Field(..., description="The url of the website")
markdown: str = Field(..., description="The markdown of the website")
content: str = Field(..., description="The content of the website")
title: str = Field(..., description="The title of the website")
description: str = Field(..., description="The description of the website")
@ -296,4 +297,4 @@ class GetWebsiteCrawlResponse(BaseModel):
Get website crawl response
"""
result: list[WebSiteInfo]
result: WebSiteInfo

View File

@ -26,12 +26,3 @@ class LocalFileDatasourcePlugin(DatasourcePlugin):
def datasource_provider_type(self) -> DatasourceProviderType:
return DatasourceProviderType.LOCAL_FILE
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
return DatasourcePlugin(
entity=self.entity,
runtime=runtime,
tenant_id=self.tenant_id,
icon=self.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

View File

@ -8,15 +8,13 @@ from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlug
class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
tenant_id: str
plugin_id: str
plugin_unique_identifier: str
def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
super().__init__(entity)
self.tenant_id = tenant_id
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier

View File

@ -69,12 +69,3 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
def datasource_provider_type(self) -> DatasourceProviderType:
return DatasourceProviderType.ONLINE_DOCUMENT
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
return DatasourcePlugin(
entity=self.entity,
runtime=runtime,
tenant_id=self.tenant_id,
icon=self.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

View File

@ -1,20 +1,18 @@
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
tenant_id: str
plugin_id: str
plugin_unique_identifier: str
def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
super().__init__(entity)
self.tenant_id = tenant_id
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@ -25,7 +23,7 @@ class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderC
"""
return DatasourceProviderType.ONLINE_DOCUMENT
def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore
def get_datasource(self, datasource_name: str) -> OnlineDocumentDatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
@ -41,7 +39,7 @@ class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderC
if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")
return DatasourcePlugin(
return OnlineDocumentDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,

View File

@ -7,7 +7,6 @@ from core.datasource.entities.datasource_entities import (
GetWebsiteCrawlResponse,
)
from core.plugin.impl.datasource import PluginDatasourceManager
from core.plugin.utils.converter import convert_parameters_to_plugin_format
class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
@ -38,9 +37,7 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
) -> GetWebsiteCrawlResponse:
manager = PluginDatasourceManager()
datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
return manager.invoke_first_step(
return manager.get_website_crawl(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
@ -52,12 +49,3 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
def datasource_provider_type(self) -> DatasourceProviderType:
return DatasourceProviderType.WEBSITE_CRAWL
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
return DatasourcePlugin(
entity=self.entity,
runtime=runtime,
tenant_id=self.tenant_id,
icon=self.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

View File

@ -1,20 +1,18 @@
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
tenant_id: str
plugin_id: str
plugin_unique_identifier: str
def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
super().__init__(entity)
self.tenant_id = tenant_id
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@ -25,7 +23,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon
"""
return DatasourceProviderType.WEBSITE_CRAWL
def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore
def get_datasource(self, datasource_name: str) -> WebsiteCrawlDatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
@ -41,7 +39,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon
if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")
return DatasourcePlugin(
return WebsiteCrawlDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,

View File

@ -7,7 +7,6 @@ from typing import Any, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.oauth import OAuthSchema
from core.plugin.entities.parameters import (
PluginParameter,
PluginParameterOption,
@ -350,7 +349,6 @@ class ToolProviderEntity(BaseModel):
identity: ToolProviderIdentity
plugin_id: Optional[str] = None
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
oauth_schema: Optional[OAuthSchema] = Field(default=None, description="The oauth schema of the tool provider")
class ToolProviderEntityWithPlugin(ToolProviderEntity):

View File

@ -4,6 +4,9 @@ from typing import Any, cast
from core.datasource.entities.datasource_entities import (
DatasourceParameter,
DatasourceProviderType,
GetOnlineDocumentPageContentRequest,
GetOnlineDocumentPageContentResponse,
GetWebsiteCrawlRequest,
GetWebsiteCrawlResponse,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
@ -54,6 +57,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
provider_id=node_data.provider_id,
datasource_name=node_data.datasource_name,
tenant_id=self.tenant_id,
datasource_type=DatasourceProviderType(node_data.provider_type),
)
except DatasourceNodeError as e:
yield RunCompletedEvent(
@ -82,38 +86,43 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
)
try:
# TODO: handle result
if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
result = datasource_runtime._get_online_document_page_content(
user_id=self.user_id,
datasource_parameters=parameters,
provider_type=node_data.provider_type,
online_document_result: GetOnlineDocumentPageContentResponse = (
datasource_runtime._get_online_document_page_content(
user_id=self.user_id,
datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters),
provider_type=datasource_runtime.datasource_provider_type(),
)
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"result": result.result.model_dump(),
"datasource_type": datasource_runtime.datasource_provider_type,
},
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"online_document": online_document_result.result.model_dump(),
"datasource_type": datasource_runtime.datasource_provider_type,
},
)
)
elif datasource_runtime.datasource_provider_type == DatasourceProviderType.WEBSITE_CRAWL:
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
user_id=self.user_id,
datasource_parameters=parameters,
provider_type=node_data.provider_type,
datasource_parameters=GetWebsiteCrawlRequest(**parameters),
provider_type=datasource_runtime.datasource_provider_type(),
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"result": result.result.model_dump(),
"datasource_type": datasource_runtime.datasource_provider_type,
},
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"website": website_crawl_result.result.model_dump(),
"datasource_type": datasource_runtime.datasource_provider_type,
},
)
)
else:
raise DatasourceNodeError(