jyong
2025-05-28 17:56:04 +08:00
parent 5fc2bc58a9
commit 7f59ffe7af
32 changed files with 680 additions and 202 deletions


@@ -51,7 +51,10 @@ from services.entities.knowledge_entities.knowledge_entities import (
    RetrievalModel,
    SegmentUpdateArgs,
)
from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
from services.entities.knowledge_entities.rag_pipeline_entities import (
    KnowledgeBaseUpdateConfiguration,
    RagPipelineDatasetCreateEntity,
)
from services.errors.account import InvalidActionError, NoPermissionError
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
from services.errors.dataset import DatasetNameDuplicateError
@@ -59,11 +62,11 @@ from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureModel, FeatureService
from services.rag_pipeline.rag_pipeline_dsl_service import ImportMode, RagPipelineDslService, RagPipelineImportInfo
from services.tag_service import TagService
from services.vector_service import VectorService
from tasks.batch_clean_document_task import batch_clean_document_task
from tasks.clean_notion_document_task import clean_notion_document_task
from tasks.deal_dataset_index_update_task import deal_dataset_index_update_task
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
@@ -278,47 +281,6 @@ class DatasetService:
        db.session.commit()
        return dataset

    @staticmethod
    def create_rag_pipeline_dataset(
        tenant_id: str,
        rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
    ):
        # check if dataset name already exists
        if (
            db.session.query(Dataset)
            .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
            .first()
        ):
            raise DatasetNameDuplicateError(
                f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
            )
        dataset = Dataset(
            name=rag_pipeline_dataset_create_entity.name,
            description=rag_pipeline_dataset_create_entity.description,
            permission=rag_pipeline_dataset_create_entity.permission,
            provider="vendor",
            runtime_mode="rag-pipeline",
            icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
        )
        with Session(db.engine) as session:
            rag_pipeline_dsl_service = RagPipelineDslService(session)
            account = cast(Account, current_user)
            rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
                account=account,
                import_mode=ImportMode.YAML_CONTENT.value,
                yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
                dataset=dataset,
            )
            return {
                "id": rag_pipeline_import_info.id,
                "dataset_id": dataset.id,
                "pipeline_id": rag_pipeline_import_info.pipeline_id,
                "status": rag_pipeline_import_info.status,
                "imported_dsl_version": rag_pipeline_import_info.imported_dsl_version,
                "current_dsl_version": rag_pipeline_import_info.current_dsl_version,
                "error": rag_pipeline_import_info.error,
            }
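
For context, a minimal sketch of how a caller might invoke this method. The permission value, the icon payload, and the surrounding variable names are assumptions for illustration, not taken from this diff:

# Hypothetical caller sketch; field values are assumptions.
create_entity = RagPipelineDatasetCreateEntity(
    name="my-dataset",
    description="Created from an imported RAG pipeline DSL",
    permission="only_me",  # assumed permission value
    icon_info=icon_info,   # an icon payload supporting .model_dump()
    yaml_content=dsl_yaml, # the pipeline DSL exported as YAML text
)
result = DatasetService.create_rag_pipeline_dataset(
    tenant_id=current_user.current_tenant_id,
    rag_pipeline_dataset_create_entity=create_entity,
)
# the returned dict carries the import status plus the new dataset/pipeline ids
print(result["dataset_id"], result["pipeline_id"], result["status"])
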
    @staticmethod
    def get_dataset(dataset_id) -> Optional[Dataset]:
@@ -529,6 +491,130 @@ class DatasetService:
        if action:
            deal_dataset_vector_index_task.delay(dataset_id, action)
        return dataset

    @staticmethod
    def update_rag_pipeline_dataset_settings(
        session: Session,
        dataset: Dataset,
        knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
        has_published: bool = False,
    ):
        if not has_published:
            dataset.chunk_structure = knowledge_base_setting.chunk_structure
            index_method = knowledge_base_setting.index_method
            dataset.indexing_technique = index_method.indexing_technique
            if index_method.indexing_technique == "high_quality":
                model_manager = ModelManager()
                embedding_model = model_manager.get_model_instance(
                    tenant_id=current_user.current_tenant_id,
                    provider=index_method.embedding_setting.embedding_provider_name,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=index_method.embedding_setting.embedding_model_name,
                )
                dataset.embedding_model = embedding_model.model
                dataset.embedding_model_provider = embedding_model.provider
                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                    embedding_model.provider, embedding_model.model
                )
                dataset.collection_binding_id = dataset_collection_binding.id
            elif index_method.indexing_technique == "economy":
                dataset.keyword_number = index_method.economy_setting.keyword_number
            else:
                raise ValueError("Invalid index method")
            dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
            session.add(dataset)
        else:
            if dataset.chunk_structure and dataset.chunk_structure != knowledge_base_setting.chunk_structure:
                raise ValueError("Chunk structure is not allowed to be updated.")
            action = None
            if dataset.indexing_technique != knowledge_base_setting.index_method.indexing_technique:
                # the indexing technique itself is being changed
                if knowledge_base_setting.index_method.indexing_technique == "economy":
                    raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
                elif knowledge_base_setting.index_method.indexing_technique == "high_quality":
                    action = "add"
                    # resolve the embedding model configured in the update request
                    try:
                        model_manager = ModelManager()
                        embedding_model = model_manager.get_model_instance(
                            tenant_id=current_user.current_tenant_id,
                            provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name,
                            model_type=ModelType.TEXT_EMBEDDING,
                            model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name,
                        )
                        dataset.embedding_model = embedding_model.model
                        dataset.embedding_model_provider = embedding_model.provider
                        dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                            embedding_model.provider, embedding_model.model
                        )
                        dataset.collection_binding_id = dataset_collection_binding.id
                    except LLMBadRequestError:
                        raise ValueError(
                            "No Embedding Model available. Please configure a valid provider "
                            "in the Settings -> Model Provider."
                        )
                    except ProviderTokenNotInitError as ex:
                        raise ValueError(ex.description)
            else:
                # the indexing technique is unchanged; compare provider ids in their normalized
                # plugin form and only touch the embedding settings when they actually differ
                if dataset.indexing_technique == "high_quality":
                    skip_embedding_update = False
                    try:
                        # normalize the existing model provider id
                        plugin_model_provider = dataset.embedding_model_provider
                        plugin_model_provider_str = None
                        if plugin_model_provider:
                            plugin_model_provider_str = str(ModelProviderID(plugin_model_provider))

                        # normalize the model provider id from the request
                        new_plugin_model_provider = (
                            knowledge_base_setting.index_method.embedding_setting.embedding_provider_name
                        )
                        new_plugin_model_provider_str = None
                        if new_plugin_model_provider:
                            new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider))

                        # update the embedding model only if the provider or model differs from the current settings
                        if (
                            plugin_model_provider_str != new_plugin_model_provider_str
                            or knowledge_base_setting.index_method.embedding_setting.embedding_model_name
                            != dataset.embedding_model
                        ):
                            action = "update"
                            model_manager = ModelManager()
                            try:
                                embedding_model = model_manager.get_model_instance(
                                    tenant_id=current_user.current_tenant_id,
                                    provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name,
                                    model_type=ModelType.TEXT_EMBEDDING,
                                    model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name,
                                )
                            except ProviderTokenNotInitError:
                                # the configured embedding model is unavailable; keep the
                                # existing settings and skip the rest of the update
                                skip_embedding_update = True
                            if not skip_embedding_update:
                                dataset.embedding_model = embedding_model.model
                                dataset.embedding_model_provider = embedding_model.provider
                                dataset_collection_binding = (
                                    DatasetCollectionBindingService.get_dataset_collection_binding(
                                        embedding_model.provider, embedding_model.model
                                    )
                                )
                                dataset.collection_binding_id = dataset_collection_binding.id
                    except LLMBadRequestError:
                        raise ValueError(
                            "No Embedding Model available. Please configure a valid provider "
                            "in the Settings -> Model Provider."
                        )
                    except ProviderTokenNotInitError as ex:
                        raise ValueError(ex.description)
                elif dataset.indexing_technique == "economy":
                    if dataset.keyword_number != knowledge_base_setting.index_method.economy_setting.keyword_number:
                        dataset.keyword_number = knowledge_base_setting.index_method.economy_setting.keyword_number
            dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
            session.add(dataset)
            session.commit()
            if action:
                deal_dataset_index_update_task.delay(dataset.id, action)
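
A minimal sketch of how a caller might drive this method, assuming the KnowledgeBaseUpdateConfiguration has already been parsed and validated upstream (dataset_id and request_payload are hypothetical names):

# Hypothetical caller sketch; parsing and permission checks are assumed upstream.
with Session(db.engine) as session:
    dataset = session.query(Dataset).filter_by(id=dataset_id).first()
    if not dataset:
        raise ValueError("Dataset not found.")
    knowledge_base_setting = KnowledgeBaseUpdateConfiguration(**request_payload)
    DatasetService.update_rag_pipeline_dataset_settings(
        session=session,
        dataset=dataset,
        knowledge_base_setting=knowledge_base_setting,
        has_published=True,  # the dataset's pipeline has already been published
    )

Note that the method only calls session.commit() on the has_published path; on the draft path it merely stages the dataset with session.add(), so the caller is responsible for committing.
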
    @staticmethod
    def delete_dataset(dataset_id, user):