Author: jyong
Date: 2025-05-23 19:30:48 +08:00
parent 70d2c78176
commit 6d547447d3
11 changed files with 157 additions and 72 deletions

View File

@@ -32,6 +32,7 @@ from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerat
from extensions.ext_database import db
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, Pipeline
from models.model import AppMode
from services.dataset_service import DocumentService
logger = logging.getLogger(__name__)
@@ -91,7 +92,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline,
@@ -107,19 +108,23 @@ class PipelineGenerator(BaseAppGenerator):
for datasource_info in datasource_info_list:
workflow_run_id = str(uuid.uuid4())
document_id = None
dataset = pipeline.dataset
if not dataset:
raise ValueError("Dataset not found")
if invoke_from == InvokeFrom.PUBLISHED:
position = DocumentService.get_documents_position(pipeline.dataset_id)
position = DocumentService.get_documents_position(pipeline.dataset_id)
document = self._build_document(
tenant_id=pipeline.tenant_id,
dataset_id=pipeline.dataset_id,
built_in_field_enabled=pipeline.dataset.built_in_field_enabled,
built_in_field_enabled=dataset.built_in_field_enabled,
datasource_type=datasource_type,
datasource_info=datasource_info,
created_from="rag-pipeline",
position=position,
account=user,
batch=batch,
document_form=pipeline.dataset.chunk_structure,
document_form=dataset.chunk_structure,
)
db.session.add(document)
db.session.commit()
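
The hunk above now resolves pipeline.dataset once, fails fast when it is missing, and creates one Document record per entry in datasource_info_list before the run is kicked off. A minimal standalone sketch of that shape, assuming a stand-in record type and a hypothetical next_position callback in place of DocumentService.get_documents_position (both are illustrative, not this codebase's API):

from dataclasses import dataclass
from typing import Any, Callable, Optional

@dataclass
class DocumentRecord:
    # stand-in for the ORM Document model referenced in the diff
    dataset_id: str
    position: int
    batch: str
    datasource_type: str
    datasource_info: dict[str, Any]

def create_documents(
    dataset_id: Optional[str],
    batch: str,
    datasource_type: str,
    datasource_info_list: list[dict[str, Any]],
    next_position: Callable[[str], int],
) -> list[DocumentRecord]:
    if not dataset_id:
        # mirrors the new fail-fast guard: no dataset, no run
        raise ValueError("Dataset not found")
    records = []
    for datasource_info in datasource_info_list:
        records.append(
            DocumentRecord(
                dataset_id=dataset_id,
                position=next_position(dataset_id),
                batch=batch,
                datasource_type=datasource_type,
                datasource_info=datasource_info,
            )
        )
    return records
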
@@ -127,10 +132,12 @@ class PipelineGenerator(BaseAppGenerator):
# init application generate entity
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
pipline_config=pipeline_config,
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=datasource_type,
datasource_info=datasource_info,
dataset_id=pipeline.dataset_id,
dataset_id=dataset.id,
start_node_id=start_node_id,
batch=batch,
document_id=document_id,
inputs=self._prepare_user_inputs(
@@ -160,17 +167,28 @@ class PipelineGenerator(BaseAppGenerator):
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
return self._generate(
pipeline=pipeline,
workflow=workflow,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
if invoke_from == InvokeFrom.DEBUGGER:
return self._generate(
pipeline=pipeline,
workflow=workflow,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
else:
self._generate(
pipeline=pipeline,
workflow=workflow,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
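
With this branching, only debugger invocations return the response from _generate; published invocations run the pipeline for its side effects, and the method falls through and returns None, which is why None was added to the return annotation earlier in this diff. A caller-side sketch of handling the widened return type (consume_pipeline_result is an illustrative helper, not part of this codebase):

from collections.abc import Generator, Mapping
from typing import Any, Union

def consume_pipeline_result(
    result: Union[Mapping[str, Any], Generator[Union[Mapping[str, Any], str], None, None], None],
) -> None:
    if result is None:
        # published run: documents were created and the workflow was started,
        # so progress is tracked through the document records, not a response stream
        print("pipeline run started, nothing to stream")
    elif isinstance(result, Mapping):
        # blocking debugger run: a single result payload
        print("blocking result:", dict(result))
    else:
        # streaming debugger run: iterate the event stream
        for event in result:
            print("stream event:", event)
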
def _generate(
self,
@@ -201,7 +219,7 @@ class PipelineGenerator(BaseAppGenerator):
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
app_mode=pipeline.mode,
app_mode=AppMode.RAG_PIPELINE,
)
# new thread
@@ -256,12 +274,18 @@ class PipelineGenerator(BaseAppGenerator):
raise ValueError("inputs is required")
# convert to app config
app_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=args["datasource_type"],
datasource_info=args["datasource_info"],
dataset_id=pipeline.dataset_id,
batch=args["batch"],
document_id=args["document_id"],
inputs={},
files=[],
user_id=user.id,
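
With this change the debugger entry point builds a RagPipelineGenerateEntity directly from the request args, so the payload is expected to carry datasource and document fields alongside the node inputs. An illustrative args mapping, based only on the keys read in the hunk above; the concrete values are invented for demonstration:

# Hypothetical debugger request payload; keys mirror the args[...] lookups above,
# values are placeholders.
args = {
    "datasource_type": "online_document",
    "datasource_info": {"url": "https://example.com/doc"},
    "batch": "20250523-demo-batch",
    "document_id": "8f14e45f-ea0a-4c1b-9d5e-000000000000",
    "inputs": {},
}
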
@@ -288,7 +312,7 @@ class PipelineGenerator(BaseAppGenerator):
)
return self._generate(
app_model=app_model,
pipeline=pipeline,
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
@@ -299,7 +323,7 @@ class PipelineGenerator(BaseAppGenerator):
def single_loop_generate(
self,
app_model: App,
pipeline: Pipeline,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
@@ -323,7 +347,7 @@ class PipelineGenerator(BaseAppGenerator):
raise ValueError("inputs is required")
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
app_config = WorkflowAppConfigManager.get_app_config(pipeline=pipeline, workflow=workflow)
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
@@ -353,7 +377,7 @@ class PipelineGenerator(BaseAppGenerator):
)
return self._generate(
app_model=app_model,
pipeline=pipeline,
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,

View File

@@ -1,5 +1,6 @@
import logging
from typing import Optional, cast
from collections.abc import Mapping
from typing import Any, Optional, cast
from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -12,6 +13,7 @@ from core.app.entities.app_invoke_entities import (
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.dataset import Pipeline
@@ -100,6 +102,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
SystemVariableKey.DOCUMENT_ID: self.application_generate_entity.document_id,
SystemVariableKey.BATCH: self.application_generate_entity.batch,
SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id,
SystemVariableKey.DATASOURCE_TYPE: self.application_generate_entity.datasource_type,
SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info,
}
variable_pool = VariablePool(
@@ -110,7 +114,10 @@ class PipelineRunner(WorkflowBasedAppRunner):
)
# init graph
graph = self._init_graph(graph_config=workflow.graph_dict)
graph = self._init_rag_pipeline_graph(
graph_config=workflow.graph_dict,
start_node_id=self.application_generate_entity.start_node_id,
)
# RUN WORKFLOW
workflow_entry = WorkflowEntry(
@@ -152,3 +159,43 @@ class PipelineRunner(WorkflowBasedAppRunner):
# return workflow
return workflow
def _init_rag_pipeline_graph(self, graph_config: Mapping[str, Any], start_node_id: Optional[str] = None) -> Graph:
"""
Init pipeline graph
"""
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
nodes = graph_config.get("nodes", [])
edges = graph_config.get("edges", [])
real_run_nodes = []
real_edges = []
exclude_node_ids = []
for node in nodes:
node_id = node.get("id")
node_type = node.get("data", {}).get("type", "")
if node_type == "datasource":
if start_node_id != node_id:
exclude_node_ids.append(node_id)
continue
real_run_nodes.append(node)
for edge in edges:
if edge.get("source") in exclude_node_ids:
continue
real_edges.append(edge)
graph_config = dict(graph_config)
graph_config["nodes"] = real_run_nodes
graph_config["edges"] = real_edges
# init graph
graph = Graph.init(graph_config=graph_config)
if not graph:
raise ValueError("graph not found in workflow")
return graph
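
_init_rag_pipeline_graph keeps the datasource node selected as start_node_id, drops every other datasource node, and removes edges whose source is one of the dropped nodes, so a pipeline with several datasource entry points runs as if only the chosen one exists. A standalone sketch of the same filtering, independent of the project's Graph class and using hypothetical node ids:

from typing import Any, Optional

def filter_rag_pipeline_graph(graph_config: dict[str, Any], start_node_id: Optional[str]) -> dict[str, Any]:
    # same node/edge filtering as _init_rag_pipeline_graph, without the Graph.init call
    exclude_node_ids = []
    nodes = []
    for node in graph_config.get("nodes", []):
        if node.get("data", {}).get("type", "") == "datasource" and node.get("id") != start_node_id:
            exclude_node_ids.append(node.get("id"))
            continue
        nodes.append(node)
    edges = [edge for edge in graph_config.get("edges", []) if edge.get("source") not in exclude_node_ids]
    return {**graph_config, "nodes": nodes, "edges": edges}

# Hypothetical two-datasource graph: only the branch behind "ds_notion" survives
# when start_node_id="ds_notion".
demo_graph = {
    "nodes": [
        {"id": "ds_notion", "data": {"type": "datasource"}},
        {"id": "ds_crawl", "data": {"type": "datasource"}},
        {"id": "extract", "data": {"type": "document-extractor"}},
    ],
    "edges": [
        {"source": "ds_notion", "target": "extract"},
        {"source": "ds_crawl", "target": "extract"},
    ],
}
filtered = filter_rag_pipeline_graph(demo_graph, start_node_id="ds_notion")
assert [n["id"] for n in filtered["nodes"]] == ["ds_notion", "extract"]
assert filtered["edges"] == [{"source": "ds_notion", "target": "extract"}]
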