mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 18:08:07 +08:00
r2
This commit is contained in:
@ -51,7 +51,10 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
RetrievalModel,
|
||||
SegmentUpdateArgs,
|
||||
)
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import (
|
||||
KnowledgeBaseUpdateConfiguration,
|
||||
RagPipelineDatasetCreateEntity,
|
||||
)
|
||||
from services.errors.account import InvalidActionError, NoPermissionError
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
|
||||
from services.errors.dataset import DatasetNameDuplicateError
|
||||
@ -59,11 +62,11 @@ from services.errors.document import DocumentIndexingError
|
||||
from services.errors.file import FileNotExistsError
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.feature_service import FeatureModel, FeatureService
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import ImportMode, RagPipelineDslService, RagPipelineImportInfo
|
||||
from services.tag_service import TagService
|
||||
from services.vector_service import VectorService
|
||||
from tasks.batch_clean_document_task import batch_clean_document_task
|
||||
from tasks.clean_notion_document_task import clean_notion_document_task
|
||||
from tasks.deal_dataset_index_update_task import deal_dataset_index_update_task
|
||||
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
|
||||
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
|
||||
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
|
||||
@ -278,47 +281,6 @@ class DatasetService:
|
||||
db.session.commit()
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_rag_pipeline_dataset(
|
||||
tenant_id: str,
|
||||
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
|
||||
):
|
||||
# check if dataset name already exists
|
||||
if (
|
||||
db.session.query(Dataset)
|
||||
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
|
||||
.first()
|
||||
):
|
||||
raise DatasetNameDuplicateError(
|
||||
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
|
||||
)
|
||||
|
||||
dataset = Dataset(
|
||||
name=rag_pipeline_dataset_create_entity.name,
|
||||
description=rag_pipeline_dataset_create_entity.description,
|
||||
permission=rag_pipeline_dataset_create_entity.permission,
|
||||
provider="vendor",
|
||||
runtime_mode="rag-pipeline",
|
||||
icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||
account = cast(Account, current_user)
|
||||
rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
|
||||
account=account,
|
||||
import_mode=ImportMode.YAML_CONTENT.value,
|
||||
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
|
||||
dataset=dataset,
|
||||
)
|
||||
return {
|
||||
"id": rag_pipeline_import_info.id,
|
||||
"dataset_id": dataset.id,
|
||||
"pipeline_id": rag_pipeline_import_info.pipeline_id,
|
||||
"status": rag_pipeline_import_info.status,
|
||||
"imported_dsl_version": rag_pipeline_import_info.imported_dsl_version,
|
||||
"current_dsl_version": rag_pipeline_import_info.current_dsl_version,
|
||||
"error": rag_pipeline_import_info.error,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_dataset(dataset_id) -> Optional[Dataset]:
|
||||
@ -529,6 +491,130 @@ class DatasetService:
|
||||
if action:
|
||||
deal_dataset_vector_index_task.delay(dataset_id, action)
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def update_rag_pipeline_dataset_settings(session: Session,
|
||||
dataset: Dataset,
|
||||
knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
|
||||
has_published: bool = False):
|
||||
if not has_published:
|
||||
dataset.chunk_structure = knowledge_base_setting.chunk_structure
|
||||
index_method = knowledge_base_setting.index_method
|
||||
dataset.indexing_technique = index_method.indexing_technique
|
||||
if index_method == "high_quality":
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=index_method.embedding_setting.embedding_provider_name,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=index_method.embedding_setting.embedding_model_name,
|
||||
)
|
||||
dataset.embedding_model = embedding_model.model
|
||||
dataset.embedding_model_provider = embedding_model.provider
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_model.provider, embedding_model.model
|
||||
)
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
elif index_method == "economy":
|
||||
dataset.keyword_number = index_method.economy_setting.keyword_number
|
||||
else:
|
||||
raise ValueError("Invalid index method")
|
||||
dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
|
||||
session.add(dataset)
|
||||
else:
|
||||
if dataset.chunk_structure and dataset.chunk_structure != knowledge_base_setting.chunk_structure:
|
||||
raise ValueError("Chunk structure is not allowed to be updated.")
|
||||
action = None
|
||||
if dataset.indexing_technique != knowledge_base_setting.index_method.indexing_technique:
|
||||
# if update indexing_technique
|
||||
if knowledge_base_setting.index_method.indexing_technique == "economy":
|
||||
raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
|
||||
elif knowledge_base_setting.index_method.indexing_technique == "high_quality":
|
||||
action = "add"
|
||||
# get embedding model setting
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name,
|
||||
)
|
||||
dataset.embedding_model = embedding_model.model
|
||||
dataset.embedding_model_provider = embedding_model.provider
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_model.provider, embedding_model.model
|
||||
)
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
except LLMBadRequestError:
|
||||
raise ValueError(
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(ex.description)
|
||||
else:
|
||||
# add default plugin id to both setting sets, to make sure the plugin model provider is consistent
|
||||
# Skip embedding model checks if not provided in the update request
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
skip_embedding_update = False
|
||||
try:
|
||||
# Handle existing model provider
|
||||
plugin_model_provider = dataset.embedding_model_provider
|
||||
plugin_model_provider_str = None
|
||||
if plugin_model_provider:
|
||||
plugin_model_provider_str = str(ModelProviderID(plugin_model_provider))
|
||||
|
||||
# Handle new model provider from request
|
||||
new_plugin_model_provider = knowledge_base_setting.index_method.embedding_setting.embedding_provider_name
|
||||
new_plugin_model_provider_str = None
|
||||
if new_plugin_model_provider:
|
||||
new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider))
|
||||
|
||||
# Only update embedding model if both values are provided and different from current
|
||||
if (
|
||||
plugin_model_provider_str != new_plugin_model_provider_str
|
||||
or knowledge_base_setting.index_method.embedding_setting.embedding_model_name != dataset.embedding_model
|
||||
):
|
||||
action = "update"
|
||||
model_manager = ModelManager()
|
||||
try:
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name,
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
# If we can't get the embedding model, skip updating it
|
||||
# and keep the existing settings if available
|
||||
# Skip the rest of the embedding model update
|
||||
skip_embedding_update = True
|
||||
if not skip_embedding_update:
|
||||
dataset.embedding_model = embedding_model.model
|
||||
dataset.embedding_model_provider = embedding_model.provider
|
||||
dataset_collection_binding = (
|
||||
DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_model.provider, embedding_model.model
|
||||
)
|
||||
)
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
except LLMBadRequestError:
|
||||
raise ValueError(
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(ex.description)
|
||||
elif dataset.indexing_technique == "economy":
|
||||
if dataset.keyword_number != knowledge_base_setting.index_method.economy_setting.keyword_number:
|
||||
dataset.keyword_number = knowledge_base_setting.index_method.economy_setting.keyword_number
|
||||
dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
|
||||
session.add(dataset)
|
||||
session.commit()
|
||||
if action:
|
||||
deal_dataset_index_update_task.delay(dataset.id, action)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def delete_dataset(dataset_id, user):
|
||||
|
||||
Reference in New Issue
Block a user