This commit is contained in:
jyong
2025-05-23 19:30:48 +08:00
parent 70d2c78176
commit 6d547447d3
11 changed files with 157 additions and 72 deletions

View File

@@ -6,7 +6,7 @@ import random
import time
import uuid
from collections import Counter
from typing import Any, Optional
from typing import Any, Optional, cast
from flask_login import current_user
from sqlalchemy import func, select
@@ -298,13 +298,14 @@ class DatasetService:
description=rag_pipeline_dataset_create_entity.description,
permission=rag_pipeline_dataset_create_entity.permission,
provider="vendor",
runtime_mode="rag_pipeline",
runtime_mode="rag-pipeline",
icon_info=rag_pipeline_dataset_create_entity.icon_info,
)
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
account = cast(Account, current_user)
rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
account=current_user,
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
dataset=dataset,

View File

@@ -59,12 +59,12 @@ class RagPipelineService:
if not result.get("pipeline_templates") and language != "en-US":
template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
return result.get("pipeline_templates")
return [PipelineBuiltInTemplate(**template) for template in result.get("pipeline_templates", [])]
else:
mode = "customized"
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
result = retrieval_instance.get_pipeline_templates(language)
return result.get("pipeline_templates")
return [PipelineCustomizedTemplate(**template) for template in result.get("pipeline_templates", [])]
@classmethod
def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]:

View File

@@ -97,11 +97,6 @@ def _check_version_compatibility(imported_version: str) -> ImportStatus:
class RagPipelinePendingData(BaseModel):
import_mode: str
yaml_content: str
name: str | None
description: str | None
icon_type: str | None
icon: str | None
icon_background: str | None
pipeline_id: str | None
@@ -302,10 +297,6 @@ class RagPipelineDslService:
dataset.runtime_mode = "rag_pipeline"
dataset.chunk_structure = knowledge_configuration.chunk_structure
if knowledge_configuration.index_method.indexing_technique == "high_quality":
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
knowledge_configuration.index_method.embedding_setting.embedding_provider_name, # type: ignore
knowledge_configuration.index_method.embedding_setting.embedding_model_name, # type: ignore
)
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(
@@ -445,10 +436,28 @@ class RagPipelineDslService:
dataset.runtime_mode = "rag_pipeline"
dataset.chunk_structure = knowledge_configuration.chunk_structure
if knowledge_configuration.index_method.indexing_technique == "high_quality":
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
knowledge_configuration.index_method.embedding_setting.embedding_provider_name, # type: ignore
knowledge_configuration.index_method.embedding_setting.embedding_model_name, # type: ignore
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(
DatasetCollectionBinding.provider_name
== knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
DatasetCollectionBinding.model_name
== knowledge_configuration.index_method.embedding_setting.embedding_model_name,
DatasetCollectionBinding.type == "dataset",
)
.order_by(DatasetCollectionBinding.created_at)
.first()
)
if not dataset_collection_binding:
dataset_collection_binding = DatasetCollectionBinding(
provider_name=knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
model_name=knowledge_configuration.index_method.embedding_setting.embedding_model_name,
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
type="dataset",
)
db.session.add(dataset_collection_binding)
db.session.commit()
dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = (
@@ -602,7 +611,6 @@ class RagPipelineDslService:
rag_pipeline_service.sync_draft_workflow(
pipeline=pipeline,
graph=workflow_data.get("graph", {}),
features=workflow_data.get("features", {}),
unique_hash=unique_hash,
account=account,
environment_variables=environment_variables,