mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 01:18:05 +08:00
r2
This commit is contained in:
@ -51,7 +51,10 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
RetrievalModel,
|
||||
SegmentUpdateArgs,
|
||||
)
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import (
|
||||
KnowledgeBaseUpdateConfiguration,
|
||||
RagPipelineDatasetCreateEntity,
|
||||
)
|
||||
from services.errors.account import InvalidActionError, NoPermissionError
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
|
||||
from services.errors.dataset import DatasetNameDuplicateError
|
||||
@ -59,11 +62,11 @@ from services.errors.document import DocumentIndexingError
|
||||
from services.errors.file import FileNotExistsError
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.feature_service import FeatureModel, FeatureService
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import ImportMode, RagPipelineDslService, RagPipelineImportInfo
|
||||
from services.tag_service import TagService
|
||||
from services.vector_service import VectorService
|
||||
from tasks.batch_clean_document_task import batch_clean_document_task
|
||||
from tasks.clean_notion_document_task import clean_notion_document_task
|
||||
from tasks.deal_dataset_index_update_task import deal_dataset_index_update_task
|
||||
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
|
||||
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
|
||||
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
|
||||
@ -278,47 +281,6 @@ class DatasetService:
|
||||
db.session.commit()
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_rag_pipeline_dataset(
|
||||
tenant_id: str,
|
||||
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
|
||||
):
|
||||
# check if dataset name already exists
|
||||
if (
|
||||
db.session.query(Dataset)
|
||||
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
|
||||
.first()
|
||||
):
|
||||
raise DatasetNameDuplicateError(
|
||||
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
|
||||
)
|
||||
|
||||
dataset = Dataset(
|
||||
name=rag_pipeline_dataset_create_entity.name,
|
||||
description=rag_pipeline_dataset_create_entity.description,
|
||||
permission=rag_pipeline_dataset_create_entity.permission,
|
||||
provider="vendor",
|
||||
runtime_mode="rag-pipeline",
|
||||
icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||
account = cast(Account, current_user)
|
||||
rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
|
||||
account=account,
|
||||
import_mode=ImportMode.YAML_CONTENT.value,
|
||||
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
|
||||
dataset=dataset,
|
||||
)
|
||||
return {
|
||||
"id": rag_pipeline_import_info.id,
|
||||
"dataset_id": dataset.id,
|
||||
"pipeline_id": rag_pipeline_import_info.pipeline_id,
|
||||
"status": rag_pipeline_import_info.status,
|
||||
"imported_dsl_version": rag_pipeline_import_info.imported_dsl_version,
|
||||
"current_dsl_version": rag_pipeline_import_info.current_dsl_version,
|
||||
"error": rag_pipeline_import_info.error,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_dataset(dataset_id) -> Optional[Dataset]:
|
||||
@ -529,6 +491,130 @@ class DatasetService:
|
||||
if action:
|
||||
deal_dataset_vector_index_task.delay(dataset_id, action)
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def update_rag_pipeline_dataset_settings(session: Session,
|
||||
dataset: Dataset,
|
||||
knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
|
||||
has_published: bool = False):
|
||||
if not has_published:
|
||||
dataset.chunk_structure = knowledge_base_setting.chunk_structure
|
||||
index_method = knowledge_base_setting.index_method
|
||||
dataset.indexing_technique = index_method.indexing_technique
|
||||
if index_method == "high_quality":
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=index_method.embedding_setting.embedding_provider_name,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=index_method.embedding_setting.embedding_model_name,
|
||||
)
|
||||
dataset.embedding_model = embedding_model.model
|
||||
dataset.embedding_model_provider = embedding_model.provider
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_model.provider, embedding_model.model
|
||||
)
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
elif index_method == "economy":
|
||||
dataset.keyword_number = index_method.economy_setting.keyword_number
|
||||
else:
|
||||
raise ValueError("Invalid index method")
|
||||
dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
|
||||
session.add(dataset)
|
||||
else:
|
||||
if dataset.chunk_structure and dataset.chunk_structure != knowledge_base_setting.chunk_structure:
|
||||
raise ValueError("Chunk structure is not allowed to be updated.")
|
||||
action = None
|
||||
if dataset.indexing_technique != knowledge_base_setting.index_method.indexing_technique:
|
||||
# if update indexing_technique
|
||||
if knowledge_base_setting.index_method.indexing_technique == "economy":
|
||||
raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
|
||||
elif knowledge_base_setting.index_method.indexing_technique == "high_quality":
|
||||
action = "add"
|
||||
# get embedding model setting
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name,
|
||||
)
|
||||
dataset.embedding_model = embedding_model.model
|
||||
dataset.embedding_model_provider = embedding_model.provider
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_model.provider, embedding_model.model
|
||||
)
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
except LLMBadRequestError:
|
||||
raise ValueError(
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(ex.description)
|
||||
else:
|
||||
# add default plugin id to both setting sets, to make sure the plugin model provider is consistent
|
||||
# Skip embedding model checks if not provided in the update request
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
skip_embedding_update = False
|
||||
try:
|
||||
# Handle existing model provider
|
||||
plugin_model_provider = dataset.embedding_model_provider
|
||||
plugin_model_provider_str = None
|
||||
if plugin_model_provider:
|
||||
plugin_model_provider_str = str(ModelProviderID(plugin_model_provider))
|
||||
|
||||
# Handle new model provider from request
|
||||
new_plugin_model_provider = knowledge_base_setting.index_method.embedding_setting.embedding_provider_name
|
||||
new_plugin_model_provider_str = None
|
||||
if new_plugin_model_provider:
|
||||
new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider))
|
||||
|
||||
# Only update embedding model if both values are provided and different from current
|
||||
if (
|
||||
plugin_model_provider_str != new_plugin_model_provider_str
|
||||
or knowledge_base_setting.index_method.embedding_setting.embedding_model_name != dataset.embedding_model
|
||||
):
|
||||
action = "update"
|
||||
model_manager = ModelManager()
|
||||
try:
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name,
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
# If we can't get the embedding model, skip updating it
|
||||
# and keep the existing settings if available
|
||||
# Skip the rest of the embedding model update
|
||||
skip_embedding_update = True
|
||||
if not skip_embedding_update:
|
||||
dataset.embedding_model = embedding_model.model
|
||||
dataset.embedding_model_provider = embedding_model.provider
|
||||
dataset_collection_binding = (
|
||||
DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_model.provider, embedding_model.model
|
||||
)
|
||||
)
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
except LLMBadRequestError:
|
||||
raise ValueError(
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(ex.description)
|
||||
elif dataset.indexing_technique == "economy":
|
||||
if dataset.keyword_number != knowledge_base_setting.index_method.economy_setting.keyword_number:
|
||||
dataset.keyword_number = knowledge_base_setting.index_method.economy_setting.keyword_number
|
||||
dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
|
||||
session.add(dataset)
|
||||
session.commit()
|
||||
if action:
|
||||
deal_dataset_index_update_task.delay(dataset.id, action)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def delete_dataset(dataset_id, user):
|
||||
|
||||
@ -4,29 +4,12 @@ from typing import Optional
|
||||
from flask_login import current_user
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core import datasource
|
||||
from core.datasource.__base import datasource_provider
|
||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity
|
||||
from core.helper import encrypter
|
||||
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
|
||||
from core.model_runtime.entities.provider_entities import FormType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.provider_manager import ProviderManager
|
||||
from extensions.ext_database import db
|
||||
from models.oauth import DatasourceProvider
|
||||
from models.provider import ProviderType
|
||||
from services.entities.model_provider_entities import (
|
||||
CustomConfigurationResponse,
|
||||
CustomConfigurationStatus,
|
||||
DefaultModelResponse,
|
||||
ModelWithProviderEntityResponse,
|
||||
ProviderResponse,
|
||||
ProviderWithModelsResponse,
|
||||
SimpleProviderEntityResponse,
|
||||
SystemConfigurationResponse,
|
||||
)
|
||||
from extensions.database import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -115,16 +98,26 @@ class DatasourceProviderService:
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param datasource_name: datasource name
|
||||
:param plugin_id: plugin id
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
|
||||
datasource_provider: DatasourceProvider | None = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id).first()
|
||||
if not datasource_provider:
|
||||
return None
|
||||
encrypted_credentials = datasource_provider.encrypted_credentials
|
||||
# Get provider credential secret variables
|
||||
credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider)
|
||||
|
||||
# Obfuscate provider credentials
|
||||
copy_credentials = encrypted_credentials.copy()
|
||||
for key, value in copy_credentials.items():
|
||||
if key in credential_secret_variables:
|
||||
copy_credentials[key] = encrypter.obfuscated_token(value)
|
||||
|
||||
return copy_credentials
|
||||
|
||||
|
||||
def remove_datasource_credentials(self,
|
||||
@ -136,11 +129,9 @@ class DatasourceProviderService:
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param datasource_name: datasource name
|
||||
:param plugin_id: plugin id
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id).first()
|
||||
|
||||
@ -111,3 +111,12 @@ class KnowledgeConfiguration(BaseModel):
|
||||
chunk_structure: str
|
||||
index_method: IndexMethod
|
||||
retrieval_setting: RetrievalSetting
|
||||
|
||||
|
||||
class KnowledgeBaseUpdateConfiguration(BaseModel):
|
||||
"""
|
||||
Knowledge Base Update Configuration.
|
||||
"""
|
||||
index_method: IndexMethod
|
||||
chunk_structure: str
|
||||
retrieval_setting: RetrievalSetting
|
||||
@ -69,9 +69,9 @@ class PipelineGenerateService:
|
||||
@classmethod
|
||||
def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True):
|
||||
workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER)
|
||||
return WorkflowAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().single_loop_generate(
|
||||
app_model=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
return PipelineGenerator.convert_to_event_stream(
|
||||
PipelineGenerator().single_loop_generate(
|
||||
pipeline=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -36,7 +36,9 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
|
||||
recommended_pipelines_results = []
|
||||
for pipeline_built_in_template in pipeline_built_in_templates:
|
||||
pipeline_model: Pipeline = pipeline_built_in_template.pipeline
|
||||
pipeline_model: Pipeline | None = pipeline_built_in_template.pipeline
|
||||
if not pipeline_model:
|
||||
continue
|
||||
|
||||
recommended_pipeline_result = {
|
||||
"id": pipeline_built_in_template.id,
|
||||
@ -48,7 +50,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
"privacy_policy": pipeline_built_in_template.privacy_policy,
|
||||
"position": pipeline_built_in_template.position,
|
||||
}
|
||||
dataset: Dataset = pipeline_model.dataset
|
||||
dataset: Dataset | None = pipeline_model.dataset
|
||||
if dataset:
|
||||
recommended_pipeline_result["chunk_structure"] = dataset.chunk_structure
|
||||
recommended_pipelines_results.append(recommended_pipeline_result)
|
||||
@ -72,15 +74,19 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
if not pipeline_template:
|
||||
return None
|
||||
|
||||
# get app detail
|
||||
# get pipeline detail
|
||||
pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first()
|
||||
if not pipeline or not pipeline.is_public:
|
||||
return None
|
||||
|
||||
dataset: Dataset | None = pipeline.dataset
|
||||
if not dataset:
|
||||
return None
|
||||
|
||||
return {
|
||||
"id": pipeline.id,
|
||||
"name": pipeline.name,
|
||||
"icon": pipeline.icon,
|
||||
"mode": pipeline.mode,
|
||||
"icon": pipeline_template.icon,
|
||||
"chunk_structure": dataset.chunk_structure,
|
||||
"export_data": RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline),
|
||||
}
|
||||
|
||||
@ -46,7 +46,8 @@ from models.workflow import (
|
||||
WorkflowRun,
|
||||
WorkflowType,
|
||||
)
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, PipelineTemplateInfoEntity
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
|
||||
|
||||
@ -261,8 +262,7 @@ class RagPipelineService:
|
||||
session: Session,
|
||||
pipeline: Pipeline,
|
||||
account: Account,
|
||||
marked_name: str = "",
|
||||
marked_comment: str = "",
|
||||
knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
|
||||
) -> Workflow:
|
||||
draft_workflow_stmt = select(Workflow).where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
@ -282,18 +282,25 @@ class RagPipelineService:
|
||||
graph=draft_workflow.graph,
|
||||
features=draft_workflow.features,
|
||||
created_by=account.id,
|
||||
environment_variables=draft_workflow.environment_variables,
|
||||
environment_variables=draft_workflow.environment_variables,
|
||||
conversation_variables=draft_workflow.conversation_variables,
|
||||
marked_name=marked_name,
|
||||
marked_comment=marked_comment,
|
||||
rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
|
||||
marked_name="",
|
||||
marked_comment="",
|
||||
)
|
||||
|
||||
# commit db session changes
|
||||
session.add(workflow)
|
||||
|
||||
# trigger app workflow events TODO
|
||||
# app_published_workflow_was_updated.send(pipeline, published_workflow=workflow)
|
||||
|
||||
# update dataset
|
||||
dataset = pipeline.dataset
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
DatasetService.update_rag_pipeline_dataset_settings(
|
||||
session=session,
|
||||
dataset=dataset,
|
||||
knowledge_base_setting=knowledge_base_setting,
|
||||
has_published=pipeline.is_published
|
||||
)
|
||||
# return new workflow
|
||||
return workflow
|
||||
|
||||
|
||||
@ -4,13 +4,14 @@ import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
import yaml # type: ignore
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Util.Padding import pad, unpad
|
||||
from flask_login import current_user
|
||||
from packaging import version
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
@ -31,7 +32,10 @@ from factories import variable_factory
|
||||
from models import Account
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
|
||||
from models.workflow import Workflow
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import (
|
||||
KnowledgeConfiguration,
|
||||
RagPipelineDatasetCreateEntity,
|
||||
)
|
||||
from services.plugin.dependencies_analysis import DependenciesAnalysisService
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
@ -540,9 +544,6 @@ class RagPipelineDslService:
|
||||
# Update existing pipeline
|
||||
pipeline.name = pipeline_data.get("name", pipeline.name)
|
||||
pipeline.description = pipeline_data.get("description", pipeline.description)
|
||||
pipeline.icon_type = icon_type
|
||||
pipeline.icon = icon
|
||||
pipeline.icon_background = pipeline_data.get("icon_background", pipeline.icon_background)
|
||||
pipeline.updated_by = account.id
|
||||
else:
|
||||
if account.current_tenant_id is None:
|
||||
@ -554,12 +555,6 @@ class RagPipelineDslService:
|
||||
pipeline.tenant_id = account.current_tenant_id
|
||||
pipeline.name = pipeline_data.get("name", "")
|
||||
pipeline.description = pipeline_data.get("description", "")
|
||||
pipeline.icon_type = icon_type
|
||||
pipeline.icon = icon
|
||||
pipeline.icon_background = pipeline_data.get("icon_background", "#FFFFFF")
|
||||
pipeline.enable_site = True
|
||||
pipeline.enable_api = True
|
||||
pipeline.use_icon_as_answer_icon = pipeline_data.get("use_icon_as_answer_icon", False)
|
||||
pipeline.created_by = account.id
|
||||
pipeline.updated_by = account.id
|
||||
|
||||
@ -674,26 +669,6 @@ class RagPipelineDslService:
|
||||
)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _append_model_config_export_data(cls, export_data: dict, pipeline: Pipeline) -> None:
|
||||
"""
|
||||
Append model config export data
|
||||
:param export_data: export data
|
||||
:param pipeline: Pipeline instance
|
||||
"""
|
||||
app_model_config = pipeline.app_model_config
|
||||
if not app_model_config:
|
||||
raise ValueError("Missing app configuration, please check.")
|
||||
|
||||
export_data["model_config"] = app_model_config.to_dict()
|
||||
dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict())
|
||||
export_data["dependencies"] = [
|
||||
jsonable_encoder(d.model_dump())
|
||||
for d in DependenciesAnalysisService.generate_dependencies(
|
||||
tenant_id=pipeline.tenant_id, dependencies=dependencies
|
||||
)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]:
|
||||
"""
|
||||
@ -863,3 +838,46 @@ class RagPipelineDslService:
|
||||
return pt.decode()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@staticmethod
|
||||
def create_rag_pipeline_dataset(
|
||||
tenant_id: str,
|
||||
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
|
||||
):
|
||||
# check if dataset name already exists
|
||||
if (
|
||||
db.session.query(Dataset)
|
||||
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
|
||||
.first()
|
||||
):
|
||||
raise ValueError(
|
||||
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
|
||||
)
|
||||
|
||||
dataset = Dataset(
|
||||
name=rag_pipeline_dataset_create_entity.name,
|
||||
description=rag_pipeline_dataset_create_entity.description,
|
||||
permission=rag_pipeline_dataset_create_entity.permission,
|
||||
provider="vendor",
|
||||
runtime_mode="rag-pipeline",
|
||||
icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||
account = cast(Account, current_user)
|
||||
rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
|
||||
account=account,
|
||||
import_mode=ImportMode.YAML_CONTENT.value,
|
||||
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
|
||||
dataset=dataset,
|
||||
)
|
||||
return {
|
||||
"id": rag_pipeline_import_info.id,
|
||||
"dataset_id": dataset.id,
|
||||
"pipeline_id": rag_pipeline_import_info.pipeline_id,
|
||||
"status": rag_pipeline_import_info.status,
|
||||
"imported_dsl_version": rag_pipeline_import_info.imported_dsl_version,
|
||||
"current_dsl_version": rag_pipeline_import_info.current_dsl_version,
|
||||
"error": rag_pipeline_import_info.error,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user