mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 02:18:08 +08:00
Merge branch 'main' into feat/r2

# Conflicts:
#   api/core/plugin/impl/oauth.py
#   api/core/workflow/entities/variable_pool.py
#   api/models/workflow.py
#   api/services/dataset_service.py
This commit is contained in:
@ -64,6 +64,7 @@ from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.feature_service import FeatureModel, FeatureService
|
||||
from services.tag_service import TagService
|
||||
from services.vector_service import VectorService
|
||||
from tasks.add_document_to_index_task import add_document_to_index_task
|
||||
from tasks.batch_clean_document_task import batch_clean_document_task
|
||||
from tasks.clean_notion_document_task import clean_notion_document_task
|
||||
from tasks.deal_dataset_index_update_task import deal_dataset_index_update_task
|
||||
@ -76,6 +77,7 @@ from tasks.document_indexing_update_task import document_indexing_update_task
|
||||
from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
|
||||
from tasks.enable_segments_to_index_task import enable_segments_to_index_task
|
||||
from tasks.recover_document_indexing_task import recover_document_indexing_task
|
||||
from tasks.remove_document_from_index_task import remove_document_from_index_task
|
||||
from tasks.retry_document_indexing_task import retry_document_indexing_task
|
||||
from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task
|
||||
|
||||
@ -323,188 +325,343 @@ class DatasetService:
|
||||
|
||||
@staticmethod
def update_dataset(dataset_id, data, user):
    """
    Update dataset configuration and settings.

    Resolved from the merge conflict: this is now a thin dispatcher —
    after validating existence, permission, and name uniqueness, external
    datasets are handled by ``_update_external_dataset`` and internal ones
    by ``_update_internal_dataset``.

    Args:
        dataset_id: The unique identifier of the dataset to update
        data: Dictionary containing the update data
        user: The user performing the update operation

    Returns:
        Dataset: The updated dataset object

    Raises:
        ValueError: If dataset not found or validation fails
        NoPermissionError: If user lacks permission to update the dataset
    """
    # Retrieve and validate dataset existence
    dataset = DatasetService.get_dataset(dataset_id)
    if not dataset:
        raise ValueError("Dataset not found")

    # Verify user has permission to update this dataset
    DatasetService.check_dataset_permission(dataset, user)

    # Reject the update if the (possibly new) name collides with another
    # dataset in the same tenant.
    if (
        db.session.query(Dataset)
        .filter(
            Dataset.id != dataset_id,
            Dataset.name == data.get("name", dataset.name),
            Dataset.tenant_id == dataset.tenant_id,
        )
        .first()
    ):
        raise ValueError("Dataset name already exists")

    # Handle external dataset updates separately from internal ones
    if dataset.provider == "external":
        return DatasetService._update_external_dataset(dataset, data, user)
    else:
        return DatasetService._update_internal_dataset(dataset, data, user)


@staticmethod
def _update_external_dataset(dataset, data, user):
    """
    Update external dataset configuration.

    Args:
        dataset: The dataset object to update
        data: Update data dictionary
        user: User performing the update

    Returns:
        Dataset: Updated dataset object

    Raises:
        ValueError: If external knowledge id / api id are missing, or the
            external knowledge binding cannot be found.
    """
    # Update retrieval model if provided
    external_retrieval_model = data.get("external_retrieval_model", None)
    if external_retrieval_model:
        dataset.retrieval_model = external_retrieval_model

    # Update basic dataset properties
    dataset.name = data.get("name", dataset.name)
    dataset.description = data.get("description", dataset.description)

    # Update permission if provided
    permission = data.get("permission")
    if permission:
        dataset.permission = permission

    # Validate external knowledge configuration — both ids are mandatory
    external_knowledge_id = data.get("external_knowledge_id", None)
    external_knowledge_api_id = data.get("external_knowledge_api_id", None)
    if not external_knowledge_id:
        raise ValueError("External knowledge id is required.")
    if not external_knowledge_api_id:
        raise ValueError("External knowledge api id is required.")

    # Update metadata fields
    dataset.updated_by = user.id if user else None
    dataset.updated_at = datetime.datetime.utcnow()
    db.session.add(dataset)

    # Update external knowledge binding
    DatasetService._update_external_knowledge_binding(dataset.id, external_knowledge_id, external_knowledge_api_id)

    # Commit changes to database
    db.session.commit()

    return dataset
|
||||
|
||||
@staticmethod
def _update_external_knowledge_binding(dataset_id, external_knowledge_id, external_knowledge_api_id):
    """
    Update external knowledge binding configuration.

    Args:
        dataset_id: Dataset identifier
        external_knowledge_id: External knowledge identifier
        external_knowledge_api_id: External knowledge API identifier

    Raises:
        ValueError: If no binding exists for the given dataset.
    """
    with Session(db.engine) as session:
        binding = session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()

        if binding is None:
            raise ValueError("External knowledge binding not found.")

        ids_changed = (
            binding.external_knowledge_id != external_knowledge_id
            or binding.external_knowledge_api_id != external_knowledge_api_id
        )
        if ids_changed:
            binding.external_knowledge_id = external_knowledge_id
            binding.external_knowledge_api_id = external_knowledge_api_id
            # NOTE(review): the binding is loaded in a fresh Session but added
            # to db.session here — mirrors the original flow; confirm intended.
            db.session.add(binding)
|
||||
|
||||
@staticmethod
def _update_internal_dataset(dataset, data, user):
    """
    Update internal dataset configuration.

    Args:
        dataset: The dataset object to update
        data: Update data dictionary
        user: User performing the update

    Returns:
        Dataset: Updated dataset object
    """
    # Strip fields that only apply to external datasets
    for external_only in (
        "partial_member_list",
        "external_knowledge_api_id",
        "external_knowledge_id",
        "external_retrieval_model",
    ):
        data.pop(external_only, None)

    # Drop None values, but always keep "description" (it may be cleared)
    filtered_data = {key: value for key, value in data.items() if value is not None or key == "description"}

    # Work out whether the vector index must be rebuilt/removed/updated
    action = DatasetService._handle_indexing_technique_change(dataset, data, filtered_data)

    # Bookkeeping fields
    filtered_data["updated_by"] = user.id
    filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

    # Retrieval model is always taken from the request
    filtered_data["retrieval_model"] = data["retrieval_model"]

    # Icon info is optional
    if data.get("icon_info"):
        filtered_data["icon_info"] = data.get("icon_info")

    # Persist the update in one bulk statement
    db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
    db.session.commit()

    # Kick off vector index maintenance if the indexing setup changed
    if action:
        deal_dataset_vector_index_task.delay(dataset.id, action)

    return dataset
|
||||
|
||||
@staticmethod
def _handle_indexing_technique_change(dataset, data, filtered_data):
    """
    Handle changes in indexing technique and configure embedding models accordingly.

    Args:
        dataset: Current dataset object
        data: Update data dictionary
        filtered_data: Filtered update data (mutated in place)

    Returns:
        str: Action to perform ('add', 'remove', 'update', or None)
    """
    requested = data["indexing_technique"]

    if dataset.indexing_technique == requested:
        # Technique unchanged — may still be an embedding model swap.
        return DatasetService._handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data)

    if requested == "economy":
        # Economy mode drops the embedding configuration entirely.
        filtered_data["embedding_model"] = None
        filtered_data["embedding_model_provider"] = None
        filtered_data["collection_binding_id"] = None
        return "remove"

    if requested == "high_quality":
        # High quality mode needs a concrete embedding model binding.
        DatasetService._configure_embedding_model_for_high_quality(data, filtered_data)
        return "add"

    # Any other technique value: nothing to do for the vector index.
    return None
|
||||
|
||||
@staticmethod
def _configure_embedding_model_for_high_quality(data, filtered_data):
    """
    Configure embedding model settings for high quality indexing.

    Args:
        data: Update data dictionary
        filtered_data: Filtered update data to modify

    Raises:
        ValueError: If no embedding model is available or the provider
            token is not initialized.
    """
    try:
        manager = ModelManager()
        model_instance = manager.get_model_instance(
            tenant_id=current_user.current_tenant_id,
            provider=data["embedding_model_provider"],
            model_type=ModelType.TEXT_EMBEDDING,
            model=data["embedding_model"],
        )
        filtered_data["embedding_model"] = model_instance.model
        filtered_data["embedding_model_provider"] = model_instance.provider
        collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
            model_instance.provider, model_instance.model
        )
        filtered_data["collection_binding_id"] = collection_binding.id
    except LLMBadRequestError:
        raise ValueError(
            "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
        )
    except ProviderTokenNotInitError as ex:
        raise ValueError(ex.description)
|
||||
|
||||
@staticmethod
def _handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data):
    """
    Handle embedding model updates when indexing technique remains the same.

    Args:
        dataset: Current dataset object
        data: Update data dictionary
        filtered_data: Filtered update data to modify

    Returns:
        str: Action to perform ('update' or None)
    """
    # A usable request must supply both a non-empty provider and model.
    provider_given = bool(data.get("embedding_model_provider"))
    model_given = bool(data.get("embedding_model"))

    if provider_given and model_given:
        return DatasetService._update_embedding_model_settings(dataset, data, filtered_data)

    # No usable embedding configuration in the request — keep what we have.
    DatasetService._preserve_existing_embedding_settings(dataset, filtered_data)
    return None
|
||||
|
||||
@staticmethod
def _preserve_existing_embedding_settings(dataset, filtered_data):
    """
    Preserve existing embedding model settings when not provided in update.

    Args:
        dataset: Current dataset object
        filtered_data: Filtered update data to modify
    """
    if dataset.embedding_model_provider and dataset.embedding_model:
        # Carry the dataset's current configuration forward unchanged.
        filtered_data["embedding_model_provider"] = dataset.embedding_model_provider
        filtered_data["embedding_model"] = dataset.embedding_model
        # If collection_binding_id exists, keep it too
        if dataset.collection_binding_id:
            filtered_data["collection_binding_id"] = dataset.collection_binding_id
        return

    # No usable existing settings: make sure we don't write empty values.
    for field in ("embedding_model_provider", "embedding_model"):
        if field in filtered_data and not filtered_data[field]:
            del filtered_data[field]
|
||||
|
||||
@staticmethod
def _update_embedding_model_settings(dataset, data, filtered_data):
    """
    Update embedding model settings with new values.

    Args:
        dataset: Current dataset object
        data: Update data dictionary
        filtered_data: Filtered update data to modify

    Returns:
        str: Action to perform ('update' or None)

    Raises:
        ValueError: If no embedding model is available or the provider
            token is not initialized.
    """
    try:
        # Normalize both provider ids before comparing them.
        old_provider = dataset.embedding_model_provider
        new_provider = data["embedding_model_provider"]
        old_provider_str = str(ModelProviderID(old_provider)) if old_provider else None
        new_provider_str = str(ModelProviderID(new_provider)) if new_provider else None

        provider_changed = old_provider_str != new_provider_str
        model_changed = data["embedding_model"] != dataset.embedding_model

        # Only touch the embedding settings when something actually differs.
        if provider_changed or model_changed:
            DatasetService._apply_new_embedding_settings(dataset, data, filtered_data)
            return "update"
    except LLMBadRequestError:
        raise ValueError(
            "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
        )
    except ProviderTokenNotInitError as ex:
        raise ValueError(ex.description)
    return None
|
||||
|
||||
@staticmethod
def _apply_new_embedding_settings(dataset, data, filtered_data):
    """
    Apply new embedding model settings to the dataset.

    Args:
        dataset: Current dataset object
        data: Update data dictionary
        filtered_data: Filtered update data to modify
    """
    manager = ModelManager()
    try:
        model_instance = manager.get_model_instance(
            tenant_id=current_user.current_tenant_id,
            provider=data["embedding_model_provider"],
            model_type=ModelType.TEXT_EMBEDDING,
            model=data["embedding_model"],
        )
    except ProviderTokenNotInitError:
        # Cannot resolve the requested model: fall back to the dataset's
        # existing settings (when present) and skip the update entirely.
        if dataset.embedding_model_provider and dataset.embedding_model:
            filtered_data["embedding_model_provider"] = dataset.embedding_model_provider
            filtered_data["embedding_model"] = dataset.embedding_model
            if dataset.collection_binding_id:
                filtered_data["collection_binding_id"] = dataset.collection_binding_id
        return

    # Apply the newly resolved embedding model settings.
    filtered_data["embedding_model"] = model_instance.model
    filtered_data["embedding_model_provider"] = model_instance.provider
    collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
        model_instance.provider, model_instance.model
    )
    filtered_data["collection_binding_id"] = collection_binding.id
|
||||
|
||||
@staticmethod
|
||||
def update_rag_pipeline_dataset_settings(
|
||||
session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
|
||||
@ -1157,12 +1314,17 @@ class DocumentService:
|
||||
process_rule = knowledge_config.process_rule
|
||||
if process_rule:
|
||||
if process_rule.mode in ("custom", "hierarchical"):
|
||||
dataset_process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset.id,
|
||||
mode=process_rule.mode,
|
||||
rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
|
||||
created_by=account.id,
|
||||
)
|
||||
if process_rule.rules:
|
||||
dataset_process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset.id,
|
||||
mode=process_rule.mode,
|
||||
rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
|
||||
created_by=account.id,
|
||||
)
|
||||
else:
|
||||
dataset_process_rule = dataset.latest_process_rule
|
||||
if not dataset_process_rule:
|
||||
raise ValueError("No process rule found.")
|
||||
elif process_rule.mode == "automatic":
|
||||
dataset_process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset.id,
|
||||
@ -2061,6 +2223,191 @@ class DocumentService:
|
||||
if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int):
|
||||
raise ValueError("Process rule segmentation max_tokens is invalid")
|
||||
|
||||
@staticmethod
def batch_update_document_status(dataset: Dataset, document_ids: list[str], action: str, user):
    """
    Batch update document status.

    Args:
        dataset (Dataset): The dataset object
        document_ids (list[str]): List of document IDs to update
        action (str): Action to perform (enable, disable, archive, un_archive)
        user: Current user performing the action

    Raises:
        DocumentIndexingError: If document is being indexed or not in correct state
        ValueError: If action is invalid
    """
    if not document_ids:
        return

    # Early validation of action parameter
    valid_actions = ["enable", "disable", "archive", "un_archive"]
    if action not in valid_actions:
        raise ValueError(f"Invalid action: {action}. Must be one of {valid_actions}")

    documents_to_update = []

    # First pass: validate all documents and prepare updates
    for document_id in document_ids:
        document = DocumentService.get_document(dataset.id, document_id)
        if not document:
            continue

        # Reject the whole batch if any document is mid-indexing
        indexing_cache_key = f"document_{document.id}_indexing"
        cache_result = redis_client.get(indexing_cache_key)
        if cache_result is not None:
            raise DocumentIndexingError(f"Document:{document.name} is being indexed, please try again later")

        # Prepare update based on action
        update_info = DocumentService._prepare_document_status_update(document, action, user)
        if update_info:
            documents_to_update.append(update_info)

    # Second pass: apply all updates in a single transaction
    if documents_to_update:
        try:
            for update_info in documents_to_update:
                document = update_info["document"]
                # Apply updates to the document
                for field, value in update_info["updates"].items():
                    setattr(document, field, value)
                db.session.add(document)
            # Batch commit all changes
            db.session.commit()
        except Exception:
            # Rollback on any error; bare raise preserves the original traceback
            db.session.rollback()
            raise

        # Execute async tasks and set Redis cache after successful commit.
        # propagation_error captures an async-dispatch failure so the
        # remaining documents are still processed before it is re-raised.
        propagation_error = None
        for update_info in documents_to_update:
            try:
                # Execute async tasks after successful commit
                if update_info["async_task"]:
                    task_info = update_info["async_task"]
                    task_info["function"].delay(*task_info["args"])
            except Exception as e:
                # Log the error but do not rollback the transaction;
                # lazy %-args keep the formatting out of the hot path.
                logging.exception("Error executing async task for document %s", update_info["document"].id)
                propagation_error = e
            try:
                # Set Redis cache if needed after successful commit
                if update_info["set_cache"]:
                    indexing_cache_key = f"document_{update_info['document'].id}_indexing"
                    redis_client.setex(indexing_cache_key, 600, 1)
            except Exception:
                # Log the error but do not rollback the transaction
                logging.exception("Error setting cache for document %s", update_info["document"].id)
        # Raise any propagation error after all updates
        if propagation_error:
            raise propagation_error
|
||||
|
||||
@staticmethod
def _prepare_document_status_update(document, action: str, user):
    """
    Prepare document status update information.

    Args:
        document: Document object to update
        action: Action to perform
        user: Current user

    Returns:
        dict: Update information or None if no update needed
    """
    # All timestamps in the update share one naive-UTC "now".
    now = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

    if action == "enable":
        return DocumentService._prepare_enable_update(document, now)
    if action == "disable":
        return DocumentService._prepare_disable_update(document, user, now)
    if action == "archive":
        return DocumentService._prepare_archive_update(document, user, now)
    if action == "un_archive":
        return DocumentService._prepare_unarchive_update(document, now)

    # Unknown action: nothing to prepare.
    return None
|
||||
|
||||
@staticmethod
def _prepare_enable_update(document, now):
    """Prepare updates for enabling a document; returns None when already enabled."""
    if document.enabled:
        return None

    field_updates = {"enabled": True, "disabled_at": None, "disabled_by": None, "updated_at": now}
    return {
        "document": document,
        "updates": field_updates,
        # Re-adding to the index happens asynchronously after commit.
        "async_task": {"function": add_document_to_index_task, "args": [document.id]},
        "set_cache": True,
    }
|
||||
|
||||
@staticmethod
def _prepare_disable_update(document, user, now):
    """Prepare updates for disabling a document; returns None when already disabled."""
    # Only fully indexed documents may be disabled.
    if not document.completed_at or document.indexing_status != "completed":
        raise DocumentIndexingError(f"Document: {document.name} is not completed.")

    if not document.enabled:
        return None

    field_updates = {"enabled": False, "disabled_at": now, "disabled_by": user.id, "updated_at": now}
    return {
        "document": document,
        "updates": field_updates,
        # Index removal happens asynchronously after commit.
        "async_task": {"function": remove_document_from_index_task, "args": [document.id]},
        "set_cache": True,
    }
|
||||
|
||||
@staticmethod
def _prepare_archive_update(document, user, now):
    """Prepare updates for archiving a document; returns None when already archived."""
    if document.archived:
        return None

    info = {
        "document": document,
        "updates": {"archived": True, "archived_at": now, "archived_by": user.id, "updated_at": now},
        "async_task": None,
        "set_cache": False,
    }

    # An enabled document must also be pulled from the index when archived.
    if document.enabled:
        info["async_task"] = {"function": remove_document_from_index_task, "args": [document.id]}
        info["set_cache"] = True

    return info
|
||||
|
||||
@staticmethod
def _prepare_unarchive_update(document, now):
    """Prepare updates for unarchiving a document; returns None when not archived."""
    if not document.archived:
        return None

    info = {
        "document": document,
        "updates": {"archived": False, "archived_at": None, "archived_by": None, "updated_at": now},
        "async_task": None,
        "set_cache": False,
    }

    # Only re-index when the document is currently enabled.
    if document.enabled:
        info["async_task"] = {"function": add_document_to_index_task, "args": [document.id]}
        info["set_cache"] = True

    return info
|
||||
|
||||
|
||||
class SegmentService:
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user