mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
r2
This commit is contained in:
@ -161,7 +161,7 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||
args = parser.parse_args()
|
||||
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
rag_pipeline_dataset_create_entity=args,
|
||||
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(**args),
|
||||
)
|
||||
return marshal(dataset, dataset_detail_fields), 201
|
||||
|
||||
|
||||
@ -8,7 +8,6 @@ from flask_restful.inputs import int_range # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
from models.model import EndUser
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
@ -40,6 +39,7 @@ from libs.helper import TimestampField, uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline
|
||||
from models.model import EndUser
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
|
||||
@ -242,7 +242,7 @@ class DraftRagPipelineRunApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
parser.add_argument("datasource_info", type=list, required=True, location="json")
|
||||
parser.add_argument("datasource_info_list", type=list, required=True, location="json")
|
||||
parser.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -320,6 +320,9 @@ class RagPipelineDatasourceNodeRunApi(Resource):
|
||||
inputs = args.get("inputs")
|
||||
if inputs == None:
|
||||
raise ValueError("missing inputs")
|
||||
datasource_type = args.get("datasource_type")
|
||||
if datasource_type == None:
|
||||
raise ValueError("missing datasource_type")
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
result = rag_pipeline_service.run_datasource_workflow_node(
|
||||
@ -327,7 +330,7 @@ class RagPipelineDatasourceNodeRunApi(Resource):
|
||||
node_id=node_id,
|
||||
user_inputs=inputs,
|
||||
account=current_user,
|
||||
datasource_type=args.get("datasource_type"),
|
||||
datasource_type=datasource_type,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@ -32,6 +32,7 @@ from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerat
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.dataset import Document, Pipeline
|
||||
from models.model import AppMode
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -91,7 +92,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
|
||||
# convert to app config
|
||||
pipeline_config = PipelineConfigManager.get_pipeline_config(
|
||||
pipeline=pipeline,
|
||||
@ -107,19 +108,23 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
for datasource_info in datasource_info_list:
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
document_id = None
|
||||
dataset = pipeline.dataset
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
if invoke_from == InvokeFrom.PUBLISHED:
|
||||
position = DocumentService.get_documents_position(pipeline.dataset_id)
|
||||
position = DocumentService.get_documents_position(pipeline.dataset_id)
|
||||
document = self._build_document(
|
||||
tenant_id=pipeline.tenant_id,
|
||||
dataset_id=pipeline.dataset_id,
|
||||
built_in_field_enabled=pipeline.dataset.built_in_field_enabled,
|
||||
built_in_field_enabled=dataset.built_in_field_enabled,
|
||||
datasource_type=datasource_type,
|
||||
datasource_info=datasource_info,
|
||||
created_from="rag-pipeline",
|
||||
position=position,
|
||||
account=user,
|
||||
batch=batch,
|
||||
document_form=pipeline.dataset.chunk_structure,
|
||||
document_form=dataset.chunk_structure,
|
||||
)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
@ -127,10 +132,12 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
# init application generate entity
|
||||
application_generate_entity = RagPipelineGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
pipline_config=pipeline_config,
|
||||
app_config=pipeline_config,
|
||||
pipeline_config=pipeline_config,
|
||||
datasource_type=datasource_type,
|
||||
datasource_info=datasource_info,
|
||||
dataset_id=pipeline.dataset_id,
|
||||
dataset_id=dataset.id,
|
||||
start_node_id=start_node_id,
|
||||
batch=batch,
|
||||
document_id=document_id,
|
||||
inputs=self._prepare_user_inputs(
|
||||
@ -160,17 +167,28 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
pipeline=pipeline,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
application_generate_entity=application_generate_entity,
|
||||
invoke_from=invoke_from,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
)
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
return self._generate(
|
||||
pipeline=pipeline,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
application_generate_entity=application_generate_entity,
|
||||
invoke_from=invoke_from,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
)
|
||||
else:
|
||||
self._generate(
|
||||
pipeline=pipeline,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
application_generate_entity=application_generate_entity,
|
||||
invoke_from=invoke_from,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
@ -201,7 +219,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
app_mode=pipeline.mode,
|
||||
app_mode=AppMode.RAG_PIPELINE,
|
||||
)
|
||||
|
||||
# new thread
|
||||
@ -256,12 +274,18 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
raise ValueError("inputs is required")
|
||||
|
||||
# convert to app config
|
||||
app_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
|
||||
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = WorkflowAppGenerateEntity(
|
||||
application_generate_entity = RagPipelineGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
app_config=pipeline_config,
|
||||
pipeline_config=pipeline_config,
|
||||
datasource_type=args["datasource_type"],
|
||||
datasource_info=args["datasource_info"],
|
||||
dataset_id=pipeline.dataset_id,
|
||||
batch=args["batch"],
|
||||
document_id=args["document_id"],
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id=user.id,
|
||||
@ -288,7 +312,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
pipeline=pipeline,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
@ -299,7 +323,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
|
||||
def single_loop_generate(
|
||||
self,
|
||||
app_model: App,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
@ -323,7 +347,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
raise ValueError("inputs is required")
|
||||
|
||||
# convert to app config
|
||||
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
||||
app_config = WorkflowAppConfigManager.get_app_config(pipeline=pipeline, workflow=workflow)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = WorkflowAppGenerateEntity(
|
||||
@ -353,7 +377,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
pipeline=pipeline,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from typing import Optional, cast
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@ -12,6 +13,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Pipeline
|
||||
@ -100,6 +102,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
SystemVariableKey.DOCUMENT_ID: self.application_generate_entity.document_id,
|
||||
SystemVariableKey.BATCH: self.application_generate_entity.batch,
|
||||
SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id,
|
||||
SystemVariableKey.DATASOURCE_TYPE: self.application_generate_entity.datasource_type,
|
||||
SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info,
|
||||
}
|
||||
|
||||
variable_pool = VariablePool(
|
||||
@ -110,7 +114,10 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = self._init_graph(graph_config=workflow.graph_dict)
|
||||
graph = self._init_rag_pipeline_graph(
|
||||
graph_config=workflow.graph_dict,
|
||||
start_node_id=self.application_generate_entity.start_node_id,
|
||||
)
|
||||
|
||||
# RUN WORKFLOW
|
||||
workflow_entry = WorkflowEntry(
|
||||
@ -152,3 +159,43 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
|
||||
# return workflow
|
||||
return workflow
|
||||
|
||||
def _init_rag_pipeline_graph(self, graph_config: Mapping[str, Any], start_node_id: Optional[str] = None) -> Graph:
|
||||
"""
|
||||
Init pipeline graph
|
||||
"""
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
if not isinstance(graph_config.get("nodes"), list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
nodes = graph_config.get("nodes", [])
|
||||
edges = graph_config.get("edges", [])
|
||||
real_run_nodes = []
|
||||
real_edges = []
|
||||
exclude_node_ids = []
|
||||
for node in nodes:
|
||||
node_id = node.get("id")
|
||||
node_type = node.get("data", {}).get("type", "")
|
||||
if node_type == "datasource":
|
||||
if start_node_id != node_id:
|
||||
exclude_node_ids.append(node_id)
|
||||
continue
|
||||
real_run_nodes.append(node)
|
||||
for edge in edges:
|
||||
if edge.get("source") in exclude_node_ids :
|
||||
continue
|
||||
real_edges.append(edge)
|
||||
graph_config = dict(graph_config)
|
||||
graph_config["nodes"] = real_run_nodes
|
||||
graph_config["edges"] = real_edges
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
return graph
|
||||
@ -233,14 +233,14 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
|
||||
"""
|
||||
RAG Pipeline Application Generate Entity.
|
||||
"""
|
||||
|
||||
# app config
|
||||
pipline_config: WorkflowUIBasedAppConfig
|
||||
# pipeline config
|
||||
pipeline_config: WorkflowUIBasedAppConfig
|
||||
datasource_type: str
|
||||
datasource_info: Mapping[str, Any]
|
||||
dataset_id: str
|
||||
batch: str
|
||||
document_id: str
|
||||
document_id: Optional[str] = None
|
||||
start_node_id: Optional[str] = None
|
||||
|
||||
class SingleIterationRunEntity(BaseModel):
|
||||
"""
|
||||
|
||||
@ -18,3 +18,5 @@ class SystemVariableKey(StrEnum):
|
||||
DOCUMENT_ID = "document_id"
|
||||
BATCH = "batch"
|
||||
DATASET_ID = "dataset_id"
|
||||
DATASOURCE_TYPE = "datasource_type"
|
||||
DATASOURCE_INFO = "datasource_info"
|
||||
|
||||
@ -121,6 +121,8 @@ class Graph(BaseModel):
|
||||
# fetch nodes that have no predecessor node
|
||||
root_node_configs = []
|
||||
all_node_id_config_mapping: dict[str, dict] = {}
|
||||
|
||||
|
||||
for node_config in node_configs:
|
||||
node_id = node_config.get("id")
|
||||
if not node_id:
|
||||
@ -140,7 +142,8 @@ class Graph(BaseModel):
|
||||
(
|
||||
node_config.get("id")
|
||||
for node_config in root_node_configs
|
||||
if node_config.get("data", {}).get("type", "") == NodeType.START.value
|
||||
if node_config.get("data", {}).get("type", "") == NodeType.START.value
|
||||
or node_config.get("data", {}).get("type", "") == NodeType.DATASOURCE.value
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
@ -6,11 +6,8 @@ from core.datasource.entities.datasource_entities import (
|
||||
DatasourceProviderType,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
GetOnlineDocumentPageContentResponse,
|
||||
GetWebsiteCrawlRequest,
|
||||
GetWebsiteCrawlResponse,
|
||||
)
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
|
||||
from core.file import File
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
@ -42,22 +39,23 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
"""
|
||||
|
||||
node_data = cast(DatasourceNodeData, self.node_data)
|
||||
|
||||
# fetch datasource icon
|
||||
datasource_info = {
|
||||
"provider_id": node_data.provider_id,
|
||||
"plugin_unique_identifier": node_data.plugin_unique_identifier,
|
||||
}
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
# get datasource runtime
|
||||
try:
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
|
||||
datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
|
||||
|
||||
datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
|
||||
if datasource_type is None:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id=node_data.provider_id,
|
||||
datasource_name=node_data.datasource_name,
|
||||
tenant_id=self.tenant_id,
|
||||
datasource_type=DatasourceProviderType(node_data.provider_type),
|
||||
datasource_type=DatasourceProviderType(datasource_type),
|
||||
)
|
||||
except DatasourceNodeError as e:
|
||||
yield RunCompletedEvent(
|
||||
@ -75,12 +73,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
datasource_parameters = datasource_runtime.entity.parameters
|
||||
parameters = self._generate_parameters(
|
||||
datasource_parameters=datasource_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
variable_pool=variable_pool,
|
||||
node_data=self.node_data,
|
||||
)
|
||||
parameters_for_log = self._generate_parameters(
|
||||
datasource_parameters=datasource_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
variable_pool=variable_pool,
|
||||
node_data=self.node_data,
|
||||
for_log=True,
|
||||
)
|
||||
@ -106,20 +104,19 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
},
|
||||
)
|
||||
)
|
||||
elif datasource_runtime.datasource_provider_type == DatasourceProviderType.WEBSITE_CRAWL:
|
||||
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
|
||||
website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
|
||||
user_id=self.user_id,
|
||||
datasource_parameters=GetWebsiteCrawlRequest(**parameters),
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
elif (
|
||||
datasource_runtime.datasource_provider_type in (
|
||||
DatasourceProviderType.WEBSITE_CRAWL,
|
||||
DatasourceProviderType.LOCAL_FILE,
|
||||
)
|
||||
):
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"website": website_crawl_result.result.model_dump(),
|
||||
"website": datasource_info,
|
||||
"datasource_type": datasource_runtime.datasource_provider_type,
|
||||
},
|
||||
)
|
||||
|
||||
@ -6,7 +6,7 @@ import random
|
||||
import time
|
||||
import uuid
|
||||
from collections import Counter
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import func, select
|
||||
@ -298,13 +298,14 @@ class DatasetService:
|
||||
description=rag_pipeline_dataset_create_entity.description,
|
||||
permission=rag_pipeline_dataset_create_entity.permission,
|
||||
provider="vendor",
|
||||
runtime_mode="rag_pipeline",
|
||||
runtime_mode="rag-pipeline",
|
||||
icon_info=rag_pipeline_dataset_create_entity.icon_info,
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||
account = cast(Account, current_user)
|
||||
rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
|
||||
account=current_user,
|
||||
account=account,
|
||||
import_mode=ImportMode.YAML_CONTENT.value,
|
||||
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
|
||||
dataset=dataset,
|
||||
|
||||
@ -59,12 +59,12 @@ class RagPipelineService:
|
||||
if not result.get("pipeline_templates") and language != "en-US":
|
||||
template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
|
||||
result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
|
||||
return result.get("pipeline_templates")
|
||||
return [PipelineBuiltInTemplate(**template) for template in result.get("pipeline_templates", [])]
|
||||
else:
|
||||
mode = "customized"
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
||||
result = retrieval_instance.get_pipeline_templates(language)
|
||||
return result.get("pipeline_templates")
|
||||
return [PipelineCustomizedTemplate(**template) for template in result.get("pipeline_templates", [])]
|
||||
|
||||
@classmethod
|
||||
def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]:
|
||||
|
||||
@ -97,11 +97,6 @@ def _check_version_compatibility(imported_version: str) -> ImportStatus:
|
||||
class RagPipelinePendingData(BaseModel):
|
||||
import_mode: str
|
||||
yaml_content: str
|
||||
name: str | None
|
||||
description: str | None
|
||||
icon_type: str | None
|
||||
icon: str | None
|
||||
icon_background: str | None
|
||||
pipeline_id: str | None
|
||||
|
||||
|
||||
@ -302,10 +297,6 @@ class RagPipelineDslService:
|
||||
dataset.runtime_mode = "rag_pipeline"
|
||||
dataset.chunk_structure = knowledge_configuration.chunk_structure
|
||||
if knowledge_configuration.index_method.indexing_technique == "high_quality":
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
knowledge_configuration.index_method.embedding_setting.embedding_provider_name, # type: ignore
|
||||
knowledge_configuration.index_method.embedding_setting.embedding_model_name, # type: ignore
|
||||
)
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.filter(
|
||||
@ -445,10 +436,28 @@ class RagPipelineDslService:
|
||||
dataset.runtime_mode = "rag_pipeline"
|
||||
dataset.chunk_structure = knowledge_configuration.chunk_structure
|
||||
if knowledge_configuration.index_method.indexing_technique == "high_quality":
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
knowledge_configuration.index_method.embedding_setting.embedding_provider_name, # type: ignore
|
||||
knowledge_configuration.index_method.embedding_setting.embedding_model_name, # type: ignore
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.filter(
|
||||
DatasetCollectionBinding.provider_name
|
||||
== knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
|
||||
DatasetCollectionBinding.model_name
|
||||
== knowledge_configuration.index_method.embedding_setting.embedding_model_name,
|
||||
DatasetCollectionBinding.type == "dataset",
|
||||
)
|
||||
.order_by(DatasetCollectionBinding.created_at)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not dataset_collection_binding:
|
||||
dataset_collection_binding = DatasetCollectionBinding(
|
||||
provider_name=knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
|
||||
model_name=knowledge_configuration.index_method.embedding_setting.embedding_model_name,
|
||||
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
|
||||
type="dataset",
|
||||
)
|
||||
db.session.add(dataset_collection_binding)
|
||||
db.session.commit()
|
||||
dataset_collection_binding_id = dataset_collection_binding.id
|
||||
dataset.collection_binding_id = dataset_collection_binding_id
|
||||
dataset.embedding_model = (
|
||||
@ -602,7 +611,6 @@ class RagPipelineDslService:
|
||||
rag_pipeline_service.sync_draft_workflow(
|
||||
pipeline=pipeline,
|
||||
graph=workflow_data.get("graph", {}),
|
||||
features=workflow_data.get("features", {}),
|
||||
unique_hash=unique_hash,
|
||||
account=account,
|
||||
environment_variables=environment_variables,
|
||||
|
||||
Reference in New Issue
Block a user