mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 02:18:08 +08:00
r2
This commit is contained in:
@ -32,6 +32,7 @@ from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerat
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.dataset import Document, Pipeline
|
||||
from models.model import AppMode
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -91,7 +92,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
|
||||
# convert to app config
|
||||
pipeline_config = PipelineConfigManager.get_pipeline_config(
|
||||
pipeline=pipeline,
|
||||
@ -107,19 +108,23 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
for datasource_info in datasource_info_list:
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
document_id = None
|
||||
dataset = pipeline.dataset
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
if invoke_from == InvokeFrom.PUBLISHED:
|
||||
position = DocumentService.get_documents_position(pipeline.dataset_id)
|
||||
position = DocumentService.get_documents_position(pipeline.dataset_id)
|
||||
document = self._build_document(
|
||||
tenant_id=pipeline.tenant_id,
|
||||
dataset_id=pipeline.dataset_id,
|
||||
built_in_field_enabled=pipeline.dataset.built_in_field_enabled,
|
||||
built_in_field_enabled=dataset.built_in_field_enabled,
|
||||
datasource_type=datasource_type,
|
||||
datasource_info=datasource_info,
|
||||
created_from="rag-pipeline",
|
||||
position=position,
|
||||
account=user,
|
||||
batch=batch,
|
||||
document_form=pipeline.dataset.chunk_structure,
|
||||
document_form=dataset.chunk_structure,
|
||||
)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
@ -127,10 +132,12 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
# init application generate entity
|
||||
application_generate_entity = RagPipelineGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
pipline_config=pipeline_config,
|
||||
app_config=pipeline_config,
|
||||
pipeline_config=pipeline_config,
|
||||
datasource_type=datasource_type,
|
||||
datasource_info=datasource_info,
|
||||
dataset_id=pipeline.dataset_id,
|
||||
dataset_id=dataset.id,
|
||||
start_node_id=start_node_id,
|
||||
batch=batch,
|
||||
document_id=document_id,
|
||||
inputs=self._prepare_user_inputs(
|
||||
@ -160,17 +167,28 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
pipeline=pipeline,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
application_generate_entity=application_generate_entity,
|
||||
invoke_from=invoke_from,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
)
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
return self._generate(
|
||||
pipeline=pipeline,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
application_generate_entity=application_generate_entity,
|
||||
invoke_from=invoke_from,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
)
|
||||
else:
|
||||
self._generate(
|
||||
pipeline=pipeline,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
application_generate_entity=application_generate_entity,
|
||||
invoke_from=invoke_from,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
@ -201,7 +219,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
app_mode=pipeline.mode,
|
||||
app_mode=AppMode.RAG_PIPELINE,
|
||||
)
|
||||
|
||||
# new thread
|
||||
@ -256,12 +274,18 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
raise ValueError("inputs is required")
|
||||
|
||||
# convert to app config
|
||||
app_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
|
||||
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = WorkflowAppGenerateEntity(
|
||||
application_generate_entity = RagPipelineGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
app_config=pipeline_config,
|
||||
pipeline_config=pipeline_config,
|
||||
datasource_type=args["datasource_type"],
|
||||
datasource_info=args["datasource_info"],
|
||||
dataset_id=pipeline.dataset_id,
|
||||
batch=args["batch"],
|
||||
document_id=args["document_id"],
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id=user.id,
|
||||
@ -288,7 +312,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
pipeline=pipeline,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
@ -299,7 +323,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
|
||||
def single_loop_generate(
|
||||
self,
|
||||
app_model: App,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
@ -323,7 +347,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
raise ValueError("inputs is required")
|
||||
|
||||
# convert to app config
|
||||
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
||||
app_config = WorkflowAppConfigManager.get_app_config(pipeline=pipeline, workflow=workflow)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = WorkflowAppGenerateEntity(
|
||||
@ -353,7 +377,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
pipeline=pipeline,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from typing import Optional, cast
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@ -12,6 +13,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Pipeline
|
||||
@ -100,6 +102,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
SystemVariableKey.DOCUMENT_ID: self.application_generate_entity.document_id,
|
||||
SystemVariableKey.BATCH: self.application_generate_entity.batch,
|
||||
SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id,
|
||||
SystemVariableKey.DATASOURCE_TYPE: self.application_generate_entity.datasource_type,
|
||||
SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info,
|
||||
}
|
||||
|
||||
variable_pool = VariablePool(
|
||||
@ -110,7 +114,10 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = self._init_graph(graph_config=workflow.graph_dict)
|
||||
graph = self._init_rag_pipeline_graph(
|
||||
graph_config=workflow.graph_dict,
|
||||
start_node_id=self.application_generate_entity.start_node_id,
|
||||
)
|
||||
|
||||
# RUN WORKFLOW
|
||||
workflow_entry = WorkflowEntry(
|
||||
@ -152,3 +159,43 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
|
||||
# return workflow
|
||||
return workflow
|
||||
|
||||
def _init_rag_pipeline_graph(self, graph_config: Mapping[str, Any], start_node_id: Optional[str] = None) -> Graph:
|
||||
"""
|
||||
Init pipeline graph
|
||||
"""
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
if not isinstance(graph_config.get("nodes"), list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
nodes = graph_config.get("nodes", [])
|
||||
edges = graph_config.get("edges", [])
|
||||
real_run_nodes = []
|
||||
real_edges = []
|
||||
exclude_node_ids = []
|
||||
for node in nodes:
|
||||
node_id = node.get("id")
|
||||
node_type = node.get("data", {}).get("type", "")
|
||||
if node_type == "datasource":
|
||||
if start_node_id != node_id:
|
||||
exclude_node_ids.append(node_id)
|
||||
continue
|
||||
real_run_nodes.append(node)
|
||||
for edge in edges:
|
||||
if edge.get("source") in exclude_node_ids :
|
||||
continue
|
||||
real_edges.append(edge)
|
||||
graph_config = dict(graph_config)
|
||||
graph_config["nodes"] = real_run_nodes
|
||||
graph_config["edges"] = real_edges
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
return graph
|
||||
@ -233,14 +233,14 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
|
||||
"""
|
||||
RAG Pipeline Application Generate Entity.
|
||||
"""
|
||||
|
||||
# app config
|
||||
pipline_config: WorkflowUIBasedAppConfig
|
||||
# pipeline config
|
||||
pipeline_config: WorkflowUIBasedAppConfig
|
||||
datasource_type: str
|
||||
datasource_info: Mapping[str, Any]
|
||||
dataset_id: str
|
||||
batch: str
|
||||
document_id: str
|
||||
document_id: Optional[str] = None
|
||||
start_node_id: Optional[str] = None
|
||||
|
||||
class SingleIterationRunEntity(BaseModel):
|
||||
"""
|
||||
|
||||
@ -18,3 +18,5 @@ class SystemVariableKey(StrEnum):
|
||||
DOCUMENT_ID = "document_id"
|
||||
BATCH = "batch"
|
||||
DATASET_ID = "dataset_id"
|
||||
DATASOURCE_TYPE = "datasource_type"
|
||||
DATASOURCE_INFO = "datasource_info"
|
||||
|
||||
@ -121,6 +121,8 @@ class Graph(BaseModel):
|
||||
# fetch nodes that have no predecessor node
|
||||
root_node_configs = []
|
||||
all_node_id_config_mapping: dict[str, dict] = {}
|
||||
|
||||
|
||||
for node_config in node_configs:
|
||||
node_id = node_config.get("id")
|
||||
if not node_id:
|
||||
@ -140,7 +142,8 @@ class Graph(BaseModel):
|
||||
(
|
||||
node_config.get("id")
|
||||
for node_config in root_node_configs
|
||||
if node_config.get("data", {}).get("type", "") == NodeType.START.value
|
||||
if node_config.get("data", {}).get("type", "") == NodeType.START.value
|
||||
or node_config.get("data", {}).get("type", "") == NodeType.DATASOURCE.value
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
@ -6,11 +6,8 @@ from core.datasource.entities.datasource_entities import (
|
||||
DatasourceProviderType,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
GetOnlineDocumentPageContentResponse,
|
||||
GetWebsiteCrawlRequest,
|
||||
GetWebsiteCrawlResponse,
|
||||
)
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
|
||||
from core.file import File
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
@ -42,22 +39,23 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
"""
|
||||
|
||||
node_data = cast(DatasourceNodeData, self.node_data)
|
||||
|
||||
# fetch datasource icon
|
||||
datasource_info = {
|
||||
"provider_id": node_data.provider_id,
|
||||
"plugin_unique_identifier": node_data.plugin_unique_identifier,
|
||||
}
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
# get datasource runtime
|
||||
try:
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
|
||||
datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
|
||||
|
||||
datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
|
||||
if datasource_type is None:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id=node_data.provider_id,
|
||||
datasource_name=node_data.datasource_name,
|
||||
tenant_id=self.tenant_id,
|
||||
datasource_type=DatasourceProviderType(node_data.provider_type),
|
||||
datasource_type=DatasourceProviderType(datasource_type),
|
||||
)
|
||||
except DatasourceNodeError as e:
|
||||
yield RunCompletedEvent(
|
||||
@ -75,12 +73,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
datasource_parameters = datasource_runtime.entity.parameters
|
||||
parameters = self._generate_parameters(
|
||||
datasource_parameters=datasource_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
variable_pool=variable_pool,
|
||||
node_data=self.node_data,
|
||||
)
|
||||
parameters_for_log = self._generate_parameters(
|
||||
datasource_parameters=datasource_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
variable_pool=variable_pool,
|
||||
node_data=self.node_data,
|
||||
for_log=True,
|
||||
)
|
||||
@ -106,20 +104,19 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
},
|
||||
)
|
||||
)
|
||||
elif datasource_runtime.datasource_provider_type == DatasourceProviderType.WEBSITE_CRAWL:
|
||||
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
|
||||
website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
|
||||
user_id=self.user_id,
|
||||
datasource_parameters=GetWebsiteCrawlRequest(**parameters),
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
elif (
|
||||
datasource_runtime.datasource_provider_type in (
|
||||
DatasourceProviderType.WEBSITE_CRAWL,
|
||||
DatasourceProviderType.LOCAL_FILE,
|
||||
)
|
||||
):
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"website": website_crawl_result.result.model_dump(),
|
||||
"website": datasource_info,
|
||||
"datasource_type": datasource_runtime.datasource_provider_type,
|
||||
},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user