Merge branch 'feat/r2' into deploy/rag-dev

Dongyu Li
2025-06-17 13:56:00 +08:00
15 changed files with 284 additions and 270 deletions

View File

@@ -32,7 +32,11 @@ class DatasourceProviderService:
         :param credentials:
         """
         credential_valid = self.provider_manager.validate_provider_credentials(
-            tenant_id=tenant_id, user_id=current_user.id, provider=provider, plugin_id=plugin_id, credentials=credentials
+            tenant_id=tenant_id,
+            user_id=current_user.id,
+            provider=provider,
+            plugin_id=plugin_id,
+            credentials=credentials
         )
         if credential_valid:
             # Get all provider configurations of the current workspace
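The reflowed call above returns a truthy validation result that gates everything that follows. As a minimal sketch of that validate-then-persist pattern, assuming a hypothetical persistence helper (only the validate call itself comes from the diff):

    def persist_encrypted_credentials(tenant_id: str, provider: str, plugin_id: str, credentials: dict) -> None:
        ...  # stand-in for the service's real encrypt-and-save step, not shown in this diff

    def add_credentials(service, tenant_id: str, user_id: str, provider: str, plugin_id: str, credentials: dict) -> None:
        # Validate against the plugin provider first; persist only on success.
        credential_valid = service.provider_manager.validate_provider_credentials(
            tenant_id=tenant_id,
            user_id=user_id,
            provider=provider,
            plugin_id=plugin_id,
            credentials=credentials,
        )
        if not credential_valid:
            raise ValueError("Datasource credential validation failed")
        persist_encrypted_credentials(tenant_id, provider, plugin_id, credentials)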
@@ -104,7 +108,8 @@ class DatasourceProviderService:
         for datasource_provider in datasource_providers:
             encrypted_credentials = datasource_provider.encrypted_credentials
             # Get provider credential secret variables
-            credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}")
+            credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id,
+                                                                        provider_id=f"{plugin_id}/{provider}")
             # Obfuscate provider credentials
             copy_credentials = encrypted_credentials.copy()
@@ -144,7 +149,8 @@ class DatasourceProviderService:
         for datasource_provider in datasource_providers:
            encrypted_credentials = datasource_provider.encrypted_credentials
             # Get provider credential secret variables
-            credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}")
+            credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id,
+                                                                        provider_id=f"{plugin_id}/{provider}")
             # Obfuscate provider credentials
             copy_credentials = encrypted_credentials.copy()
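The two hunks above repeat the same pattern: look up which credential fields are secret, copy the credentials, then mask the secret values before returning them. The masking step itself is not shown in the diff; a minimal sketch of what it might look like, assuming a simple keep-the-edges scheme (the helper name `obfuscate` is illustrative, not from the diff):

    def obfuscate(credentials: dict, secret_variables: list[str]) -> dict:
        """Return a copy of credentials with secret values masked."""
        masked = credentials.copy()
        for key in secret_variables:
            value = masked.get(key)
            if isinstance(value, str) and value:
                # keep the first and last two characters, mask the middle
                masked[key] = value[:2] + "*" * max(len(value) - 4, 4) + value[-2:]
        return masked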
@@ -161,7 +167,12 @@ class DatasourceProviderService:
         return copy_credentials_list

-    def update_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict) -> None:
+    def update_datasource_credentials(self,
+                                      tenant_id: str,
+                                      auth_id: str,
+                                      provider: str,
+                                      plugin_id: str,
+                                      credentials: dict) -> None:
         """
         update datasource credentials.
         """

View File

@@ -15,7 +15,6 @@ import contexts
 from configs import dify_config
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.datasource.entities.datasource_entities import (
-    DatasourceInvokeMessage,
     DatasourceProviderType,
     OnlineDocumentPagesMessage,
     WebsiteCrawlMessage,
@@ -423,70 +422,71 @@ class RagPipelineService:
         return workflow_node_execution

-    def run_datasource_workflow_node_status(
-        self, pipeline: Pipeline, node_id: str, job_id: str, account: Account, datasource_type: str, is_published: bool
-    ) -> dict:
-        """
-        Run published workflow datasource
-        """
-        if is_published:
-            # fetch published workflow by app_model
-            workflow = self.get_published_workflow(pipeline=pipeline)
-        else:
-            workflow = self.get_draft_workflow(pipeline=pipeline)
-        if not workflow:
-            raise ValueError("Workflow not initialized")
-
-        # run draft workflow node
-        datasource_node_data = None
-        start_at = time.perf_counter()
-        datasource_nodes = workflow.graph_dict.get("nodes", [])
-        for datasource_node in datasource_nodes:
-            if datasource_node.get("id") == node_id:
-                datasource_node_data = datasource_node.get("data", {})
-                break
-        if not datasource_node_data:
-            raise ValueError("Datasource node data not found")
-
-        from core.datasource.datasource_manager import DatasourceManager
-
-        datasource_runtime = DatasourceManager.get_datasource_runtime(
-            provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
-            datasource_name=datasource_node_data.get("datasource_name"),
-            tenant_id=pipeline.tenant_id,
-            datasource_type=DatasourceProviderType(datasource_type),
-        )
-        datasource_provider_service = DatasourceProviderService()
-        credentials = datasource_provider_service.get_real_datasource_credentials(
-            tenant_id=pipeline.tenant_id,
-            provider=datasource_node_data.get('provider_name'),
-            plugin_id=datasource_node_data.get('plugin_id'),
-        )
-        if credentials:
-            datasource_runtime.runtime.credentials = credentials[0].get("credentials")
-        match datasource_type:
-            case DatasourceProviderType.WEBSITE_CRAWL:
-                datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
-                website_crawl_results: list[WebsiteCrawlMessage] = []
-                for website_message in datasource_runtime.get_website_crawl(
-                    user_id=account.id,
-                    datasource_parameters={"job_id": job_id},
-                    provider_type=datasource_runtime.datasource_provider_type(),
-                ):
-                    website_crawl_results.append(website_message)
-                return {
-                    "result": [result for result in website_crawl_results.result],
-                    "status": website_crawl_results.result.status,
-                    "provider_type": datasource_node_data.get("provider_type"),
-                }
-            case _:
-                raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
+    # def run_datasource_workflow_node_status(
+    #     self, pipeline: Pipeline, node_id: str, job_id: str, account: Account,
+    #     datasource_type: str, is_published: bool
+    # ) -> dict:
+    #     """
+    #     Run published workflow datasource
+    #     """
+    #     if is_published:
+    #         # fetch published workflow by app_model
+    #         workflow = self.get_published_workflow(pipeline=pipeline)
+    #     else:
+    #         workflow = self.get_draft_workflow(pipeline=pipeline)
+    #     if not workflow:
+    #         raise ValueError("Workflow not initialized")
+    #
+    #     # run draft workflow node
+    #     datasource_node_data = None
+    #     start_at = time.perf_counter()
+    #     datasource_nodes = workflow.graph_dict.get("nodes", [])
+    #     for datasource_node in datasource_nodes:
+    #         if datasource_node.get("id") == node_id:
+    #             datasource_node_data = datasource_node.get("data", {})
+    #             break
+    #     if not datasource_node_data:
+    #         raise ValueError("Datasource node data not found")
+    #
+    #     from core.datasource.datasource_manager import DatasourceManager
+    #
+    #     datasource_runtime = DatasourceManager.get_datasource_runtime(
+    #         provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
+    #         datasource_name=datasource_node_data.get("datasource_name"),
+    #         tenant_id=pipeline.tenant_id,
+    #         datasource_type=DatasourceProviderType(datasource_type),
+    #     )
+    #     datasource_provider_service = DatasourceProviderService()
+    #     credentials = datasource_provider_service.get_real_datasource_credentials(
+    #         tenant_id=pipeline.tenant_id,
+    #         provider=datasource_node_data.get('provider_name'),
+    #         plugin_id=datasource_node_data.get('plugin_id'),
+    #     )
+    #     if credentials:
+    #         datasource_runtime.runtime.credentials = credentials[0].get("credentials")
+    #     match datasource_type:
+    #
+    #         case DatasourceProviderType.WEBSITE_CRAWL:
+    #             datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
+    #             website_crawl_results: list[WebsiteCrawlMessage] = []
+    #             for website_message in datasource_runtime.get_website_crawl(
+    #                 user_id=account.id,
+    #                 datasource_parameters={"job_id": job_id},
+    #                 provider_type=datasource_runtime.datasource_provider_type(),
+    #             ):
+    #                 website_crawl_results.append(website_message)
+    #             return {
+    #                 "result": [result for result in website_crawl_results.result],
+    #                 "status": website_crawl_results.result.status,
+    #                 "provider_type": datasource_node_data.get("provider_type"),
+    #             }
+    #         case _:
+    #             raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
     def run_datasource_workflow_node(
         self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str,
         is_published: bool
-    ) -> Generator[DatasourceRunEvent, None, None]:
+    ) -> Generator[str, None, None]:
         """
         Run published workflow datasource
         """
@@ -533,25 +533,40 @@ class RagPipelineService:
         match datasource_type:
             case DatasourceProviderType.ONLINE_DOCUMENT:
                 datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
-                online_document_result: Generator[DatasourceInvokeMessage, None, None] = datasource_runtime._get_online_document_pages(
-                    user_id=account.id,
-                    datasource_parameters=user_inputs,
-                    provider_type=datasource_runtime.datasource_provider_type(),
-                )
-                for message in online_document_result:
-                    yield DatasourceRunEvent(
-                        status="success",
-                        result=message.model_dump(),
-                    )
+                online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = \
+                    datasource_runtime.get_online_document_pages(
+                        user_id=account.id,
+                        datasource_parameters=user_inputs,
+                        provider_type=datasource_runtime.datasource_provider_type(),
+                    )
+                start_time = time.time()
+                for message in online_document_result:
+                    end_time = time.time()
+                    online_document_event = DatasourceRunEvent(
+                        status="completed",
+                        data=message.result,
+                        time_consuming=round(end_time - start_time, 2)
+                    )
+                    yield json.dumps(online_document_event.model_dump())
             case DatasourceProviderType.WEBSITE_CRAWL:
                 datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
-                website_crawl_result: Generator[DatasourceInvokeMessage, None, None] = datasource_runtime._get_website_crawl(
+                website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = datasource_runtime.get_website_crawl(
                     user_id=account.id,
                     datasource_parameters=user_inputs,
                     provider_type=datasource_runtime.datasource_provider_type(),
                 )
-                yield from website_crawl_result
+                start_time = time.time()
+                for message in website_crawl_result:
+                    end_time = time.time()
+                    crawl_event = DatasourceRunEvent(
+                        status=message.result.status,
+                        data=message.result.web_info_list,
+                        total=message.result.total,
+                        completed=message.result.completed,
+                        time_consuming=round(end_time - start_time, 2)
+                    )
+                    yield json.dumps(crawl_event.model_dump())
             case _:
                 raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
@@ -952,7 +967,9 @@ class RagPipelineService:
         if not dataset:
             raise ValueError("Dataset not found")
-        max_position = db.session.query(func.max(PipelineCustomizedTemplate.position)).filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id).scalar()
+        max_position = db.session.query(
+            func.max(PipelineCustomizedTemplate.position)).filter(
+            PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id).scalar()
         from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

         dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
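One caveat with the reflowed query above: SQLAlchemy's .scalar() returns None when the tenant has no templates yet (func.max over an empty set yields NULL), so the position arithmetic downstream presumably guards for that, e.g.:

    position = (max_position or 0) + 1  # treat an empty tenant as position 0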