Merge branch 'feat/datasource' into feat/r2

# Conflicts:
#	api/services/rag_pipeline/rag_pipeline.py
#	web/app/components/workflow/constants.ts
#	web/app/components/workflow/header/run-and-history.tsx
#	web/app/components/workflow/hooks/use-nodes-interactions.ts
#	web/app/components/workflow/hooks/use-workflow-interactions.ts
#	web/app/components/workflow/hooks/use-workflow.ts
#	web/app/components/workflow/index.tsx
#	web/app/components/workflow/nodes/_base/components/panel-operator/panel-operator-popup.tsx
#	web/app/components/workflow/nodes/_base/panel.tsx
#	web/app/components/workflow/nodes/code/use-config.ts
#	web/app/components/workflow/nodes/llm/default.ts
#	web/app/components/workflow/panel/index.tsx
#	web/app/components/workflow/panel/version-history-panel/index.tsx
#	web/app/components/workflow/store/workflow/index.ts
#	web/app/components/workflow/types.ts
#	web/config/index.ts
#	web/types/workflow.ts
Author: jyong
Date: 2025-07-02 14:01:59 +08:00
753 changed files with 28675 additions and 3917 deletions

View File

@ -86,6 +86,7 @@ from .datasets import (
)
from .datasets.rag_pipeline import (
datasource_auth,
datasource_content_preview,
rag_pipeline,
rag_pipeline_datasets,
rag_pipeline_import,

View File

@ -0,0 +1,54 @@
from flask_restful import ( # type: ignore
Resource, # type: ignore
reqparse,
)
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import current_user, login_required
from models import Account
from models.dataset import Pipeline
from services.rag_pipeline.rag_pipeline import RagPipelineService
class DataSourceContentPreviewApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_rag_pipeline
def post(self, pipeline: Pipeline, node_id: str):
"""
Run datasource content preview
"""
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
args = parser.parse_args()
inputs = args.get("inputs")
if inputs is None:
raise ValueError("missing inputs")
datasource_type = args.get("datasource_type")
if datasource_type is None:
raise ValueError("missing datasource_type")
rag_pipeline_service = RagPipelineService()
preview_content = rag_pipeline_service.run_datasource_node_preview(
pipeline=pipeline,
node_id=node_id,
user_inputs=inputs,
account=current_user,
datasource_type=datasource_type,
is_published=True,
)
return preview_content, 200
api.add_resource(
DataSourceContentPreviewApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview"
)
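For reference, a minimal sketch of calling the new preview endpoint from Python. The host, pipeline id, node id, and token below are hypothetical placeholders; the console API normally authenticates via a logged-in session rather than a raw bearer token.

import requests

# Hypothetical values; substitute a real console host, pipeline id, node id, and token.
BASE = "https://cloud.example.com/console/api"
PIPELINE_ID = "11111111-2222-3333-4444-555555555555"
NODE_ID = "1719968448481"

resp = requests.post(
    f"{BASE}/rag/pipelines/{PIPELINE_ID}/workflows/published/datasource/nodes/{NODE_ID}/preview",
    headers={"Authorization": "Bearer <console-token>", "Content-Type": "application/json"},
    json={
        "inputs": {"workspace_id": "<workspace>", "page_id": "<page>", "type": "page"},
        "datasource_type": "online_document",
    },
)
resp.raise_for_status()
print(resp.json())  # the variables mapping produced by run_datasource_node_preview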

View File

@ -414,17 +414,19 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
raise ValueError("missing datasource_type")
rag_pipeline_service = RagPipelineService()
-        result = rag_pipeline_service.run_datasource_workflow_node(
-            pipeline=pipeline,
-            node_id=node_id,
-            user_inputs=inputs,
-            account=current_user,
-            datasource_type=datasource_type,
-            is_published=True,
-        )
-        return result
+        return helper.compact_generate_response(
+            PipelineGenerator.convert_to_event_stream(
+                rag_pipeline_service.run_datasource_workflow_node(
+                    pipeline=pipeline,
+                    node_id=node_id,
+                    user_inputs=inputs,
+                    account=current_user,
+                    datasource_type=datasource_type,
+                    is_published=False,
+                )
+            )
+        )
class RagPipelineDraftDatasourceNodeRunApi(Resource):
@setup_required
@ -455,21 +457,18 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
raise ValueError("missing datasource_type")
rag_pipeline_service = RagPipelineService()
-        try:
-            return helper.compact_generate_response(
-                PipelineGenerator.convert_to_event_stream(
-                    rag_pipeline_service.run_datasource_workflow_node(
-                        pipeline=pipeline,
-                        node_id=node_id,
-                        user_inputs=inputs,
-                        account=current_user,
-                        datasource_type=datasource_type,
-                        is_published=False,
-                    )
-                )
-            )
-        except Exception as e:
-            print(e)
+        return helper.compact_generate_response(
+            PipelineGenerator.convert_to_event_stream(
+                rag_pipeline_service.run_datasource_workflow_node(
+                    pipeline=pipeline,
+                    node_id=node_id,
+                    user_inputs=inputs,
+                    account=current_user,
+                    datasource_type=datasource_type,
+                    is_published=False,
+                )
+            )
+        )
class RagPipelinePublishedNodeRunApi(Resource):
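With this change both node-run endpoints stream their events over SSE instead of returning one JSON payload, and the draft endpoint stops swallowing errors with a bare print. The service method is a generator of event dicts, which PipelineGenerator.convert_to_event_stream wraps for helper.compact_generate_response. A rough sketch of what that conversion amounts to, assuming conventional "data: <json>" server-sent-event framing (the actual helpers live in the pipeline generator and libs.helper, so this is illustrative only):

import json
from collections.abc import Generator

def convert_to_event_stream_sketch(events: Generator[dict, None, None]) -> Generator[str, None, None]:
    # Serialize each yielded event dict as one SSE "data:" frame.
    for event in events:
        yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n"

# Usage sketch: iterate the frames a client would receive.
def fake_service() -> Generator[dict, None, None]:
    yield {"event": "datasource_processing", "total": 10, "completed": 3}
    yield {"event": "datasource_completed", "data": [], "total": 10, "completed": 10}

for frame in convert_to_event_stream_sketch(fake_service()):
    print(frame, end="")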

View File

@ -2,14 +2,14 @@ import contextvars
import datetime
import json
import logging
-import random
+import secrets
import threading
import time
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Optional, Union, overload
-from flask import Flask, copy_current_request_context, current_app, has_request_context
+from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker
@ -110,7 +110,7 @@ class PipelineGenerator(BaseAppGenerator):
start_node_id: str = args["start_node_id"]
datasource_type: str = args["datasource_type"]
datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
-        batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
+        batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
documents = []
if invoke_from == InvokeFrom.PUBLISHED:
for datasource_info in datasource_info_list:
@ -589,7 +589,7 @@ class PipelineGenerator(BaseAppGenerator):
if datasource_type == "local_file":
name = datasource_info["name"]
elif datasource_type == "online_document":
name = datasource_info["page_title"]
name = datasource_info["page"]["page_name"]
elif datasource_type == "website_crawl":
name = datasource_info["title"]
else:
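The batch id swaps random.randint for the secrets module: secrets.randbelow(900000) + 100000 draws a uniform six-digit number from a CSPRNG, so batch suffixes are no longer produced by a seedable PRNG. A quick standalone check of the range:

import secrets
import time

# secrets.randbelow(900000) returns 0..899999, so the sum spans 100000..999999,
# the same six-digit range random.randint(100000, 999999) produced.
batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
assert 100000 <= secrets.randbelow(900000) + 100000 <= 999999
print(batch)  # e.g. 20250702140159483920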

View File

@ -214,10 +214,11 @@ class OnlineDocumentPage(BaseModel):
"""
page_id: str = Field(..., description="The page id")
-    page_title: str = Field(..., description="The page title")
+    page_name: str = Field(..., description="The page title")
page_icon: Optional[dict] = Field(None, description="The page icon")
type: str = Field(..., description="The type of the page")
last_edited_time: str = Field(..., description="The last edited time")
parent_id: Optional[str] = Field(None, description="The parent page id")
class OnlineDocumentInfo(BaseModel):
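The rename from page_title to page_name lines up with the generator change above, where the document name is now read as datasource_info["page"]["page_name"] (note the Field description still says "The page title"). A small standalone sketch of the renamed model, reproduced here for illustration rather than imported from the datasource entities module:

from typing import Optional
from pydantic import BaseModel, Field

class OnlineDocumentPage(BaseModel):
    page_id: str = Field(..., description="The page id")
    page_name: str = Field(..., description="The page title")
    page_icon: Optional[dict] = Field(None, description="The page icon")
    type: str = Field(..., description="The type of the page")
    last_edited_time: str = Field(..., description="The last edited time")
    parent_id: Optional[str] = Field(None, description="The parent page id")

page = OnlineDocumentPage(
    page_id="abc123",
    page_name="Quarterly Notes",
    type="page",
    last_edited_time="2025-07-01T00:00:00Z",
)
print(page.model_dump()["page_name"])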

View File

@ -141,7 +141,7 @@ class PluginDatasourceManager(BasePluginClient):
datasource_provider_id = GenericProviderID(datasource_provider)
-        response = self._request_with_plugin_daemon_response_stream(
+        return self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages",
OnlineDocumentPagesMessage,
@ -159,7 +159,6 @@ class PluginDatasourceManager(BasePluginClient):
"Content-Type": "application/json",
},
)
-        yield from response
def get_online_document_page_content(
self,
@ -177,7 +176,7 @@ class PluginDatasourceManager(BasePluginClient):
datasource_provider_id = GenericProviderID(datasource_provider)
-        response = self._request_with_plugin_daemon_response_stream(
+        return self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/get_online_document_page_content",
DatasourceMessage,
@ -195,7 +194,6 @@ class PluginDatasourceManager(BasePluginClient):
"Content-Type": "application/json",
},
)
-        yield from response
def online_drive_browse_files(
self,
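Dropping the trailing yield from turns these wrappers from generator functions into plain functions that return the daemon stream directly. The semantic difference is worth noting: a generator function defers its entire body until first iteration, while the rewritten form performs the request setup eagerly and avoids one delegation frame. A minimal sketch of that distinction, using toy functions rather than the plugin client itself:

from collections.abc import Generator

def _stream() -> Generator[int, None, None]:
    yield from range(3)

def wrapper_generator() -> Generator[int, None, None]:
    print("setup")           # deferred: runs only once iteration starts
    yield from _stream()

def wrapper_return() -> Generator[int, None, None]:
    print("setup")           # eager: runs at call time, like the new code
    return _stream()

g = wrapper_generator()      # nothing printed yet
list(g)                      # prints "setup", then drains 0, 1, 2
list(wrapper_return())       # prints "setup" immediately at the call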

View File

@ -9,22 +9,30 @@ class DatasourceStreamEvent(Enum):
"""
Datasource Stream event
"""
PROCESSING = "datasource_processing"
COMPLETED = "datasource_completed"
ERROR = "datasource_error"
class BaseDatasourceEvent(BaseModel):
pass
class DatasourceErrorEvent(BaseDatasourceEvent):
event: str = DatasourceStreamEvent.ERROR.value
error: str = Field(..., description="error message")
class DatasourceCompletedEvent(BaseDatasourceEvent):
event: str = DatasourceStreamEvent.COMPLETED.value
-    data: Mapping[str,Any] | list = Field(..., description="result")
-    total: Optional[int] = Field(..., description="total")
-    completed: Optional[int] = Field(..., description="completed")
-    time_consuming: Optional[float] = Field(..., description="time consuming")
+    data: Mapping[str, Any] | list = Field(..., description="result")
+    total: Optional[int] = Field(default=0, description="total")
+    completed: Optional[int] = Field(default=0, description="completed")
+    time_consuming: Optional[float] = Field(default=0.0, description="time consuming")
class DatasourceProcessingEvent(BaseDatasourceEvent):
event: str = DatasourceStreamEvent.PROCESSING.value
total: Optional[int] = Field(..., description="total")
completed: Optional[int] = Field(..., description="completed")
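With defaults in place, DatasourceCompletedEvent can be built from just data, which is exactly how the online-document branch constructs it later in this diff. A self-contained sketch of the updated model, mirroring the definitions above rather than importing the real class from core/rag/entities:

from collections.abc import Mapping
from typing import Any, Optional
from pydantic import BaseModel, Field

class DatasourceCompletedEvent(BaseModel):
    event: str = "datasource_completed"
    data: Mapping[str, Any] | list = Field(..., description="result")
    total: Optional[int] = Field(default=0, description="total")
    completed: Optional[int] = Field(default=0, description="completed")
    time_consuming: Optional[float] = Field(default=0.0, description="time consuming")

# Previously total/completed/time_consuming were required (Field(...)); now this validates:
event = DatasourceCompletedEvent(data=[{"title": "Example"}])
print(event.model_dump())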

View File

@ -68,12 +68,15 @@ class QAChunk(BaseModel):
question: str
answer: str
class QAStructureChunk(BaseModel):
"""
QAStructureChunk.
"""
qa_chunks: list[QAChunk]
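A quick usage sketch of the new container, assuming the QAChunk fields shown above and reproducing both models standalone:

from pydantic import BaseModel

class QAChunk(BaseModel):
    question: str
    answer: str

class QAStructureChunk(BaseModel):
    """QAStructureChunk."""
    qa_chunks: list[QAChunk]

doc = QAStructureChunk(
    qa_chunks=[QAChunk(question="What is RAG?", answer="Retrieval-augmented generation.")]
)
print(len(doc.qa_chunks))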
class BaseDocumentTransformer(ABC):
"""Abstract base class for document transformation systems.

View File

@ -275,5 +275,3 @@ class AgentLogEvent(BaseAgentEvent):
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent | BaseLoopEvent

View File

@ -1,8 +1,9 @@
import json
import logging
import re
import threading
import time
from collections.abc import Callable, Generator, Sequence
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, cast
from uuid import uuid4
@ -15,16 +16,20 @@ import contexts
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import (
DatasourceMessage,
DatasourceProviderType,
GetOnlineDocumentPageContentRequest,
OnlineDocumentPagesMessage,
OnlineDriveBrowseFilesRequest,
OnlineDriveBrowseFilesResponse,
WebsiteCrawlMessage,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.rag.entities.event import BaseDatasourceEvent, DatasourceCompletedEvent, DatasourceProcessingEvent
from core.rag.entities.event import (
BaseDatasourceEvent,
DatasourceCompletedEvent,
DatasourceErrorEvent,
DatasourceProcessingEvent,
)
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
@ -64,6 +69,8 @@ from services.entities.knowledge_entities.rag_pipeline_entities import (
from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
logger = logging.getLogger(__name__)
class RagPipelineService:
@classmethod
@ -116,14 +123,6 @@ class RagPipelineService:
)
if not customized_template:
raise ValueError("Customized pipeline template not found.")
-        # check template name is exist
-        template_name = template_info.name
-        if template_name:
-            template = db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.name == template_name,
-                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
-                PipelineCustomizedTemplate.id != template_id).first()
-            if template:
-                raise ValueError("Template name is already exists")
customized_template.name = template_info.name
customized_template.description = template_info.description
customized_template.icon = template_info.icon_info.model_dump()
@ -434,157 +433,210 @@ class RagPipelineService:
return workflow_node_execution
# def run_datasource_workflow_node_status(
# self, pipeline: Pipeline, node_id: str, job_id: str, account: Account,
# datasource_type: str, is_published: bool
# ) -> dict:
# """
# Run published workflow datasource
# """
# if is_published:
# # fetch published workflow by app_model
# workflow = self.get_published_workflow(pipeline=pipeline)
# else:
# workflow = self.get_draft_workflow(pipeline=pipeline)
# if not workflow:
# raise ValueError("Workflow not initialized")
#
# # run draft workflow node
# datasource_node_data = None
# start_at = time.perf_counter()
# datasource_nodes = workflow.graph_dict.get("nodes", [])
# for datasource_node in datasource_nodes:
# if datasource_node.get("id") == node_id:
# datasource_node_data = datasource_node.get("data", {})
# break
# if not datasource_node_data:
# raise ValueError("Datasource node data not found")
#
# from core.datasource.datasource_manager import DatasourceManager
#
# datasource_runtime = DatasourceManager.get_datasource_runtime(
# provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
# datasource_name=datasource_node_data.get("datasource_name"),
# tenant_id=pipeline.tenant_id,
# datasource_type=DatasourceProviderType(datasource_type),
# )
# datasource_provider_service = DatasourceProviderService()
# credentials = datasource_provider_service.get_real_datasource_credentials(
# tenant_id=pipeline.tenant_id,
# provider=datasource_node_data.get('provider_name'),
# plugin_id=datasource_node_data.get('plugin_id'),
# )
# if credentials:
# datasource_runtime.runtime.credentials = credentials[0].get("credentials")
# match datasource_type:
#
# case DatasourceProviderType.WEBSITE_CRAWL:
# datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
# website_crawl_results: list[WebsiteCrawlMessage] = []
# for website_message in datasource_runtime.get_website_crawl(
# user_id=account.id,
# datasource_parameters={"job_id": job_id},
# provider_type=datasource_runtime.datasource_provider_type(),
# ):
# website_crawl_results.append(website_message)
# return {
# "result": [result for result in website_crawl_results.result],
# "status": website_crawl_results.result.status,
# "provider_type": datasource_node_data.get("provider_type"),
# }
# case _:
# raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
def run_datasource_workflow_node(
-        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str,
-        is_published: bool
+        self,
+        pipeline: Pipeline,
+        node_id: str,
+        user_inputs: dict,
+        account: Account,
+        datasource_type: str,
+        is_published: bool,
    ) -> Generator[BaseDatasourceEvent, None, None]:
        """
        Run published workflow datasource
        """
-        if is_published:
-            # fetch published workflow by app_model
-            workflow = self.get_published_workflow(pipeline=pipeline)
-        else:
-            workflow = self.get_draft_workflow(pipeline=pipeline)
-        if not workflow:
-            raise ValueError("Workflow not initialized")
-        # run draft workflow node
-        datasource_node_data = None
-        start_at = time.perf_counter()
-        datasource_nodes = workflow.graph_dict.get("nodes", [])
-        for datasource_node in datasource_nodes:
-            if datasource_node.get("id") == node_id:
-                datasource_node_data = datasource_node.get("data", {})
-                break
-        if not datasource_node_data:
-            raise ValueError("Datasource node data not found")
-        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
-        for key, value in datasource_parameters.items():
-            if not user_inputs.get(key):
-                user_inputs[key] = value["value"]
-        from core.datasource.datasource_manager import DatasourceManager
-        datasource_runtime = DatasourceManager.get_datasource_runtime(
-            provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
-            datasource_name=datasource_node_data.get("datasource_name"),
-            tenant_id=pipeline.tenant_id,
-            datasource_type=DatasourceProviderType(datasource_type),
-        )
-        datasource_provider_service = DatasourceProviderService()
-        credentials = datasource_provider_service.get_real_datasource_credentials(
-            tenant_id=pipeline.tenant_id,
-            provider=datasource_node_data.get("provider_name"),
-            plugin_id=datasource_node_data.get("plugin_id"),
-        )
-        if credentials:
-            datasource_runtime.runtime.credentials = credentials[0].get("credentials")
-        match datasource_type:
-            case DatasourceProviderType.ONLINE_DOCUMENT:
-                datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
-                online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
-                    datasource_runtime.get_online_document_pages(
-                        user_id=account.id,
-                        datasource_parameters=user_inputs,
-                        provider_type=datasource_runtime.datasource_provider_type(),
-                    )
-                )
-                start_time = time.time()
-                for message in online_document_result:
-                    end_time = time.time()
-                    online_document_event = DatasourceCompletedEvent(
-                        data=message.result,
-                        time_consuming=round(end_time - start_time, 2),
-                    )
-                    yield online_document_event.model_dump()
-            case DatasourceProviderType.WEBSITE_CRAWL:
-                datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
-                website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = datasource_runtime.get_website_crawl(
-                    user_id=account.id,
-                    datasource_parameters=user_inputs,
-                    provider_type=datasource_runtime.datasource_provider_type(),
-                )
-                start_time = time.time()
-                for message in website_crawl_result:
-                    end_time = time.time()
-                    if message.result.status == "completed":
-                        crawl_event = DatasourceCompletedEvent(
-                            data=message.result.web_info_list,
-                            total=message.result.total,
-                            completed=message.result.completed,
-                            time_consuming=round(end_time - start_time, 2)
-                        )
-                    else:
-                        crawl_event = DatasourceProcessingEvent(
-                            total=message.result.total,
-                            completed=message.result.completed,
-                        )
-                    yield crawl_event.model_dump()
-            case _:
-                raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
+        try:
+            if is_published:
+                # fetch published workflow by app_model
+                workflow = self.get_published_workflow(pipeline=pipeline)
+            else:
+                workflow = self.get_draft_workflow(pipeline=pipeline)
+            if not workflow:
+                raise ValueError("Workflow not initialized")
+            # run draft workflow node
+            datasource_node_data = None
+            datasource_nodes = workflow.graph_dict.get("nodes", [])
+            for datasource_node in datasource_nodes:
+                if datasource_node.get("id") == node_id:
+                    datasource_node_data = datasource_node.get("data", {})
+                    break
+            if not datasource_node_data:
+                raise ValueError("Datasource node data not found")
+            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
+            for key, value in datasource_parameters.items():
+                if not user_inputs.get(key):
+                    user_inputs[key] = value["value"]
+            from core.datasource.datasource_manager import DatasourceManager
+            datasource_runtime = DatasourceManager.get_datasource_runtime(
+                provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
+                datasource_name=datasource_node_data.get("datasource_name"),
+                tenant_id=pipeline.tenant_id,
+                datasource_type=DatasourceProviderType(datasource_type),
+            )
+            datasource_provider_service = DatasourceProviderService()
+            credentials = datasource_provider_service.get_real_datasource_credentials(
+                tenant_id=pipeline.tenant_id,
+                provider=datasource_node_data.get("provider_name"),
+                plugin_id=datasource_node_data.get("plugin_id"),
+            )
+            if credentials:
+                datasource_runtime.runtime.credentials = credentials[0].get("credentials")
+            match datasource_type:
+                case DatasourceProviderType.ONLINE_DOCUMENT:
+                    datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
+                    online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
+                        datasource_runtime.get_online_document_pages(
+                            user_id=account.id,
+                            datasource_parameters=user_inputs,
+                            provider_type=datasource_runtime.datasource_provider_type(),
+                        )
+                    )
+                    start_time = time.time()
+                    start_event = DatasourceProcessingEvent(
+                        total=0,
+                        completed=0,
+                    )
+                    yield start_event.model_dump()
+                    try:
+                        for message in online_document_result:
+                            end_time = time.time()
+                            online_document_event = DatasourceCompletedEvent(
+                                data=message.result, time_consuming=round(end_time - start_time, 2)
+                            )
+                            yield online_document_event.model_dump()
+                    except Exception as e:
+                        logger.exception("Error during online document.")
+                        yield DatasourceErrorEvent(error=str(e)).model_dump()
+                case DatasourceProviderType.WEBSITE_CRAWL:
+                    datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
+                    website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = (
+                        datasource_runtime.get_website_crawl(
+                            user_id=account.id,
+                            datasource_parameters=user_inputs,
+                            provider_type=datasource_runtime.datasource_provider_type(),
+                        )
+                    )
+                    start_time = time.time()
+                    try:
+                        for message in website_crawl_result:
+                            end_time = time.time()
+                            if message.result.status == "completed":
+                                crawl_event = DatasourceCompletedEvent(
+                                    data=message.result.web_info_list,
+                                    total=message.result.total,
+                                    completed=message.result.completed,
+                                    time_consuming=round(end_time - start_time, 2),
+                                )
+                            else:
+                                crawl_event = DatasourceProcessingEvent(
+                                    total=message.result.total,
+                                    completed=message.result.completed,
+                                )
+                            yield crawl_event.model_dump()
+                    except Exception as e:
+                        logger.exception("Error during website crawl.")
+                        yield DatasourceErrorEvent(error=str(e)).model_dump()
+                case _:
+                    raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
+        except Exception as e:
+            logger.exception("Error in run_datasource_workflow_node.")
+            yield DatasourceErrorEvent(error=str(e)).model_dump()
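One wrinkle worth noting: the signature advertises Generator[BaseDatasourceEvent, None, None], but every yield hands back event.model_dump(), i.e. a plain dict. A caller therefore dispatches on the "event" key; a minimal consumption sketch under that assumption:

# Minimal consumer sketch: the service yields dicts shaped by the event models above.
def consume(events):
    for event in events:
        match event.get("event"):
            case "datasource_processing":
                print(f"progress: {event.get('completed')}/{event.get('total')}")
            case "datasource_completed":
                print(f"done in {event.get('time_consuming')}s")
            case "datasource_error":
                raise RuntimeError(event.get("error"))

consume(iter([
    {"event": "datasource_processing", "total": 5, "completed": 2},
    {"event": "datasource_completed", "data": [], "total": 5, "completed": 5, "time_consuming": 1.23},
]))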
def run_datasource_node_preview(
self,
pipeline: Pipeline,
node_id: str,
user_inputs: dict,
account: Account,
datasource_type: str,
is_published: bool,
) -> Mapping[str, Any]:
"""
Run published workflow datasource
"""
try:
if is_published:
# fetch published workflow by app_model
workflow = self.get_published_workflow(pipeline=pipeline)
else:
workflow = self.get_draft_workflow(pipeline=pipeline)
if not workflow:
raise ValueError("Workflow not initialized")
# run draft workflow node
datasource_node_data = None
datasource_nodes = workflow.graph_dict.get("nodes", [])
for datasource_node in datasource_nodes:
if datasource_node.get("id") == node_id:
datasource_node_data = datasource_node.get("data", {})
break
if not datasource_node_data:
raise ValueError("Datasource node data not found")
datasource_parameters = datasource_node_data.get("datasource_parameters", {})
for key, value in datasource_parameters.items():
if not user_inputs.get(key):
user_inputs[key] = value["value"]
from core.datasource.datasource_manager import DatasourceManager
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
datasource_name=datasource_node_data.get("datasource_name"),
tenant_id=pipeline.tenant_id,
datasource_type=DatasourceProviderType(datasource_type),
)
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_real_datasource_credentials(
tenant_id=pipeline.tenant_id,
provider=datasource_node_data.get("provider_name"),
plugin_id=datasource_node_data.get("plugin_id"),
)
if credentials:
datasource_runtime.runtime.credentials = credentials[0].get("credentials")
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
online_document_result: Generator[DatasourceMessage, None, None] = (
datasource_runtime.get_online_document_page_content(
user_id=account.id,
datasource_parameters=GetOnlineDocumentPageContentRequest(
workspace_id=user_inputs.get("workspace_id"),
page_id=user_inputs.get("page_id"),
type=user_inputs.get("type"),
),
provider_type=datasource_type,
)
)
try:
variables: dict[str, Any] = {}
for message in online_document_result:
if message.type == DatasourceMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
else:
variables[variable_name] = variable_value
return variables
except Exception as e:
logger.exception("Error during get online document content.")
raise RuntimeError(str(e))
#TODO Online Drive
case _:
raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
except Exception as e:
logger.exception("Error in run_datasource_node_preview.")
raise RuntimeError(str(e))
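The preview path folds a stream of variable messages into one mapping: streamed string chunks are concatenated under their variable name, while non-streamed values overwrite. A standalone sketch of that folding logic, with a simplified message shape standing in for DatasourceMessage:

from dataclasses import dataclass
from typing import Any

@dataclass
class VarMessage:                  # simplified stand-in for DatasourceMessage
    variable_name: str
    variable_value: Any
    stream: bool

def fold_variables(messages: list[VarMessage]) -> dict[str, Any]:
    variables: dict[str, Any] = {}
    for m in messages:
        if m.stream:
            if not isinstance(m.variable_value, str):
                raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
            # Streamed chunks accumulate under the same variable name.
            variables[m.variable_name] = variables.get(m.variable_name, "") + m.variable_value
        else:
            variables[m.variable_name] = m.variable_value
    return variables

print(fold_variables([
    VarMessage("content", "Hello, ", True),
    VarMessage("content", "world", True),
    VarMessage("page_id", "abc123", False),
]))  # {'content': 'Hello, world', 'page_id': 'abc123'}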
def run_free_workflow_node(
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
@ -755,24 +807,77 @@ class RagPipelineService:
return workflow
-    def get_first_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]:
+    def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
"""
Get second step parameters of rag pipeline
"""
workflow = self.get_published_workflow(pipeline=pipeline)
if not workflow:
raise ValueError("Workflow not initialized")
# get second step node
rag_pipeline_variables = workflow.rag_pipeline_variables
if not rag_pipeline_variables:
return []
# get datasource provider
datasource_provider_variables = [
item
for item in rag_pipeline_variables
if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
]
return datasource_provider_variables
def get_published_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
"""
Get first step parameters of rag pipeline
"""
-        workflow = self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline)
-        if not workflow:
+        published_workflow = self.get_published_workflow(pipeline=pipeline)
+        if not published_workflow:
raise ValueError("Workflow not initialized")
# get second step node
datasource_node_data = None
-        datasource_nodes = workflow.graph_dict.get("nodes", [])
+        datasource_nodes = published_workflow.graph_dict.get("nodes", [])
for datasource_node in datasource_nodes:
if datasource_node.get("id") == node_id:
datasource_node_data = datasource_node.get("data", {})
break
if not datasource_node_data:
raise ValueError("Datasource node data not found")
-        variables = workflow.rag_pipeline_variables
+        variables = datasource_node_data.get("variables", {})
if variables:
variables_map = {item["variable"]: item for item in variables}
else:
return []
datasource_parameters = datasource_node_data.get("datasource_parameters", {})
user_input_variables = []
for key, value in datasource_parameters.items():
if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
user_input_variables.append(variables_map.get(key, {}))
return user_input_variables
def get_draft_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
"""
Get first step parameters of rag pipeline
"""
draft_workflow = self.get_draft_workflow(pipeline=pipeline)
if not draft_workflow:
raise ValueError("Workflow not initialized")
# get second step node
datasource_node_data = None
datasource_nodes = draft_workflow.graph_dict.get("nodes", [])
for datasource_node in datasource_nodes:
if datasource_node.get("id") == node_id:
datasource_node_data = datasource_node.get("data", {})
break
if not datasource_node_data:
raise ValueError("Datasource node data not found")
variables = datasource_node_data.get("variables", {})
if variables:
variables_map = {item["variable"]: item for item in variables}
else:
@ -781,21 +886,16 @@ class RagPipelineService:
user_input_variables = []
for key, value in datasource_parameters.items():
if value.get("value") and isinstance(value.get("value"), str):
-                pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
-                match = re.match(pattern, value["value"])
-                if match:
-                    full_path = match.group(1)
-                    last_part = full_path.split('.')[-1]
-                    user_input_variables.append(variables_map.get(last_part, {}))
+                if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
+                    user_input_variables.append(variables_map.get(key, {}))
return user_input_variables
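Both first-step helpers treat a parameter as user input only when its value is not a {{#node_id.variable#}} reference; the rewritten pattern also requires each dotted segment after the node id to start with a letter or underscore rather than any word character. A quick demonstration of the match behavior:

import re

# The reference pattern used by the rewritten helpers: each dotted segment after the
# node id must start with [a-zA-Z_].
PATTERN = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"

print(bool(re.match(PATTERN, "{{#1719968448481.page_id#}}")))  # True: a variable reference
print(bool(re.match(PATTERN, "{{#node.2page#}}")))             # False: segment starts with a digit
print(bool(re.match(PATTERN, "https://example.com")))          # False: literal value, so user input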
-    def get_second_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]:
+    def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
"""
Get second step parameters of rag pipeline
"""
-        workflow = self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline)
+        workflow = self.get_draft_workflow(pipeline=pipeline)
if not workflow:
raise ValueError("Workflow not initialized")
@ -803,32 +903,13 @@ class RagPipelineService:
rag_pipeline_variables = workflow.rag_pipeline_variables
if not rag_pipeline_variables:
return []
-        variables_map = {item["variable"]: item for item in rag_pipeline_variables}
-        # get datasource node data
-        datasource_node_data = None
-        datasource_nodes = workflow.graph_dict.get("nodes", [])
-        for datasource_node in datasource_nodes:
-            if datasource_node.get("id") == node_id:
-                datasource_node_data = datasource_node.get("data", {})
-                break
-        if datasource_node_data:
-            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
-            for key, value in datasource_parameters.items():
-                if value.get("value") and isinstance(value.get("value"), str):
-                    pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
-                    match = re.match(pattern, value["value"])
-                    if match:
-                        full_path = match.group(1)
-                        last_part = full_path.split('.')[-1]
-                        variables_map.pop(last_part)
-        all_second_step_variables = list(variables_map.values())
        # get datasource provider
        datasource_provider_variables = [
-            item
-            for item in all_second_step_variables
-            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
-        ]
+            item
+            for item in rag_pipeline_variables
+            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
+        ]
return datasource_provider_variables
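The second-step helpers now skip the variables_map bookkeeping and simply return every pipeline variable that belongs to the datasource node or is shared. A tiny sketch of that filter with sample data:

rag_pipeline_variables = [
    {"variable": "chunk_size", "belong_to_node_id": "node-2"},
    {"variable": "separator", "belong_to_node_id": "shared"},
    {"variable": "url", "belong_to_node_id": "node-1"},
]

node_id = "node-2"
second_step = [
    item
    for item in rag_pipeline_variables
    if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
]
print([v["variable"] for v in second_step])  # ['chunk_size', 'separator']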
def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
@ -950,16 +1031,6 @@ class RagPipelineService:
if not dataset:
raise ValueError("Dataset not found")
-        # check template name is exist
-        template_name = args.get("name")
-        if template_name:
-            template = db.session.query(PipelineCustomizedTemplate).filter(
-                PipelineCustomizedTemplate.name == template_name,
-                PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
-            ).first()
-            if template:
-                raise ValueError("Template name is already exists")
max_position = (
db.session.query(func.max(PipelineCustomizedTemplate.position))
.filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)