mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 02:18:08 +08:00
r2
This commit is contained in:
@ -6,7 +6,7 @@ import random
|
||||
import time
|
||||
import uuid
|
||||
from collections import Counter
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import func, select
|
||||
@ -298,13 +298,14 @@ class DatasetService:
|
||||
description=rag_pipeline_dataset_create_entity.description,
|
||||
permission=rag_pipeline_dataset_create_entity.permission,
|
||||
provider="vendor",
|
||||
runtime_mode="rag_pipeline",
|
||||
runtime_mode="rag-pipeline",
|
||||
icon_info=rag_pipeline_dataset_create_entity.icon_info,
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||
account = cast(Account, current_user)
|
||||
rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
|
||||
account=current_user,
|
||||
account=account,
|
||||
import_mode=ImportMode.YAML_CONTENT.value,
|
||||
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
|
||||
dataset=dataset,
|
||||
|
||||
@ -59,12 +59,12 @@ class RagPipelineService:
|
||||
if not result.get("pipeline_templates") and language != "en-US":
|
||||
template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
|
||||
result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
|
||||
return result.get("pipeline_templates")
|
||||
return [PipelineBuiltInTemplate(**template) for template in result.get("pipeline_templates", [])]
|
||||
else:
|
||||
mode = "customized"
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
||||
result = retrieval_instance.get_pipeline_templates(language)
|
||||
return result.get("pipeline_templates")
|
||||
return [PipelineCustomizedTemplate(**template) for template in result.get("pipeline_templates", [])]
|
||||
|
||||
@classmethod
|
||||
def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]:
|
||||
|
||||
@ -97,11 +97,6 @@ def _check_version_compatibility(imported_version: str) -> ImportStatus:
|
||||
class RagPipelinePendingData(BaseModel):
|
||||
import_mode: str
|
||||
yaml_content: str
|
||||
name: str | None
|
||||
description: str | None
|
||||
icon_type: str | None
|
||||
icon: str | None
|
||||
icon_background: str | None
|
||||
pipeline_id: str | None
|
||||
|
||||
|
||||
@ -302,10 +297,6 @@ class RagPipelineDslService:
|
||||
dataset.runtime_mode = "rag_pipeline"
|
||||
dataset.chunk_structure = knowledge_configuration.chunk_structure
|
||||
if knowledge_configuration.index_method.indexing_technique == "high_quality":
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
knowledge_configuration.index_method.embedding_setting.embedding_provider_name, # type: ignore
|
||||
knowledge_configuration.index_method.embedding_setting.embedding_model_name, # type: ignore
|
||||
)
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.filter(
|
||||
@ -445,10 +436,28 @@ class RagPipelineDslService:
|
||||
dataset.runtime_mode = "rag_pipeline"
|
||||
dataset.chunk_structure = knowledge_configuration.chunk_structure
|
||||
if knowledge_configuration.index_method.indexing_technique == "high_quality":
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
knowledge_configuration.index_method.embedding_setting.embedding_provider_name, # type: ignore
|
||||
knowledge_configuration.index_method.embedding_setting.embedding_model_name, # type: ignore
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.filter(
|
||||
DatasetCollectionBinding.provider_name
|
||||
== knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
|
||||
DatasetCollectionBinding.model_name
|
||||
== knowledge_configuration.index_method.embedding_setting.embedding_model_name,
|
||||
DatasetCollectionBinding.type == "dataset",
|
||||
)
|
||||
.order_by(DatasetCollectionBinding.created_at)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not dataset_collection_binding:
|
||||
dataset_collection_binding = DatasetCollectionBinding(
|
||||
provider_name=knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
|
||||
model_name=knowledge_configuration.index_method.embedding_setting.embedding_model_name,
|
||||
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
|
||||
type="dataset",
|
||||
)
|
||||
db.session.add(dataset_collection_binding)
|
||||
db.session.commit()
|
||||
dataset_collection_binding_id = dataset_collection_binding.id
|
||||
dataset.collection_binding_id = dataset_collection_binding_id
|
||||
dataset.embedding_model = (
|
||||
@ -602,7 +611,6 @@ class RagPipelineDslService:
|
||||
rag_pipeline_service.sync_draft_workflow(
|
||||
pipeline=pipeline,
|
||||
graph=workflow_data.get("graph", {}),
|
||||
features=workflow_data.get("features", {}),
|
||||
unique_hash=unique_hash,
|
||||
account=account,
|
||||
environment_variables=environment_variables,
|
||||
|
||||
Reference in New Issue
Block a user