mirror of https://github.com/langgenius/dify.git
synced 2026-05-06 02:18:08 +08:00
Merge remote-tracking branch 'myori/main' into feat/collaboration
@@ -548,7 +548,7 @@ class UpdateConfig(BaseSettings):
 
 class WorkflowVariableTruncationConfig(BaseSettings):
     WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
-        # 100KB
+        # 1000 KiB
        1024_000,
         description="Maximum size for variable to trigger final truncation.",
     )
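The comment fix above tracks the arithmetic: 1024_000 is a digit-grouped Python literal for 1,024,000 bytes, and 1,024,000 / 1024 = 1000, so the threshold is 1000 KiB rather than the 100KB the old comment claimed. A quick check:

limit = 1024_000       # underscores only group digits: 1,024,000 bytes
print(limit / 1024)    # 1000.0 -> exactly 1000 KiB, not 100 KB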
@@ -49,62 +49,80 @@ class IndexingRunner:
         self.storage = storage
         self.model_manager = ModelManager()
 
+    def _handle_indexing_error(self, document_id: str, error: Exception) -> None:
+        """Handle indexing errors by updating document status."""
+        logger.exception("consume document failed")
+        document = db.session.get(DatasetDocument, document_id)
+        if document:
+            document.indexing_status = "error"
+            error_message = getattr(error, "description", str(error))
+            document.error = str(error_message)
+            document.stopped_at = naive_utc_now()
+            db.session.commit()
+
     def run(self, dataset_documents: list[DatasetDocument]):
         """Run the indexing process."""
         for dataset_document in dataset_documents:
+            document_id = dataset_document.id
             try:
+                # Re-query the document to ensure it's bound to the current session
+                requeried_document = db.session.get(DatasetDocument, document_id)
+                if not requeried_document:
+                    logger.warning("Document not found, skipping document id: %s", document_id)
+                    continue
+
                 # get dataset
-                dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
+                dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
 
                 if not dataset:
                     raise ValueError("no dataset found")
                 # get the process rule
                 stmt = select(DatasetProcessRule).where(
-                    DatasetProcessRule.id == dataset_document.dataset_process_rule_id
+                    DatasetProcessRule.id == requeried_document.dataset_process_rule_id
                 )
                 processing_rule = db.session.scalar(stmt)
                 if not processing_rule:
                     raise ValueError("no process rule found")
-                index_type = dataset_document.doc_form
+                index_type = requeried_document.doc_form
                 index_processor = IndexProcessorFactory(index_type).init_index_processor()
                 # extract
-                text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
+                text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
 
                 # transform
                 documents = self._transform(
-                    index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
+                    index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
                 )
                 # save segment
-                self._load_segments(dataset, dataset_document, documents)
+                self._load_segments(dataset, requeried_document, documents)
 
                 # load
                 self._load(
                     index_processor=index_processor,
                     dataset=dataset,
-                    dataset_document=dataset_document,
+                    dataset_document=requeried_document,
                     documents=documents,
                 )
             except DocumentIsPausedError:
-                raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
+                raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
             except ProviderTokenNotInitError as e:
-                dataset_document.indexing_status = "error"
-                dataset_document.error = str(e.description)
-                dataset_document.stopped_at = naive_utc_now()
-                db.session.commit()
+                self._handle_indexing_error(document_id, e)
             except ObjectDeletedError:
-                logger.warning("Document deleted, document id: %s", dataset_document.id)
+                logger.warning("Document deleted, document id: %s", document_id)
             except Exception as e:
-                logger.exception("consume document failed")
-                dataset_document.indexing_status = "error"
-                dataset_document.error = str(e)
-                dataset_document.stopped_at = naive_utc_now()
-                db.session.commit()
+                self._handle_indexing_error(document_id, e)
 
     def run_in_splitting_status(self, dataset_document: DatasetDocument):
         """Run the indexing process when the index_status is splitting."""
+        document_id = dataset_document.id
         try:
+            # Re-query the document to ensure it's bound to the current session
+            requeried_document = db.session.get(DatasetDocument, document_id)
+            if not requeried_document:
+                logger.warning("Document not found: %s", document_id)
+                return
+
             # get dataset
-            dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
+            dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
 
             if not dataset:
                 raise ValueError("no dataset found")
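The requeried_document pattern above guards against ORM instances that have gone stale or become detached from the active SQLAlchemy session (for example, objects handed across task boundaries or invalidated by an intervening commit). A minimal self-contained sketch of the idea, with a toy Document model standing in for DatasetDocument:

from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Document(Base):  # hypothetical stand-in for DatasetDocument
    __tablename__ = "documents"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    indexing_status: Mapped[str] = mapped_column(String, default="waiting")

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Document(id="doc-1"))
    session.commit()

with Session(engine) as session:
    # Re-fetch by primary key so the instance is bound to *this* session;
    # session.get() checks the identity map first, then the database.
    doc = session.get(Document, "doc-1")
    if doc is None:
        print("document vanished meanwhile: skip instead of raising")
    else:
        doc.indexing_status = "indexing"
        session.commit()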
@@ -112,57 +130,60 @@ class IndexingRunner:
             # get exist document_segment list and delete
             document_segments = (
                 db.session.query(DocumentSegment)
-                .filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
+                .filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
                 .all()
             )
 
             for document_segment in document_segments:
                 db.session.delete(document_segment)
-                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                     # delete child chunks
                     db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
             db.session.commit()
             # get the process rule
-            stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
+            stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == requeried_document.dataset_process_rule_id)
             processing_rule = db.session.scalar(stmt)
             if not processing_rule:
                 raise ValueError("no process rule found")
 
-            index_type = dataset_document.doc_form
+            index_type = requeried_document.doc_form
             index_processor = IndexProcessorFactory(index_type).init_index_processor()
             # extract
-            text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
+            text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
 
             # transform
             documents = self._transform(
-                index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
+                index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
             )
             # save segment
-            self._load_segments(dataset, dataset_document, documents)
+            self._load_segments(dataset, requeried_document, documents)
 
             # load
             self._load(
-                index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
+                index_processor=index_processor,
+                dataset=dataset,
+                dataset_document=requeried_document,
+                documents=documents,
             )
         except DocumentIsPausedError:
-            raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
+            raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
         except ProviderTokenNotInitError as e:
-            dataset_document.indexing_status = "error"
-            dataset_document.error = str(e.description)
-            dataset_document.stopped_at = naive_utc_now()
-            db.session.commit()
+            self._handle_indexing_error(document_id, e)
         except Exception as e:
-            logger.exception("consume document failed")
-            dataset_document.indexing_status = "error"
-            dataset_document.error = str(e)
-            dataset_document.stopped_at = naive_utc_now()
-            db.session.commit()
+            self._handle_indexing_error(document_id, e)
 
     def run_in_indexing_status(self, dataset_document: DatasetDocument):
         """Run the indexing process when the index_status is indexing."""
+        document_id = dataset_document.id
         try:
+            # Re-query the document to ensure it's bound to the current session
+            requeried_document = db.session.get(DatasetDocument, document_id)
+            if not requeried_document:
+                logger.warning("Document not found: %s", document_id)
+                return
+
             # get dataset
-            dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
+            dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
 
             if not dataset:
                 raise ValueError("no dataset found")
@@ -170,7 +191,7 @@ class IndexingRunner:
             # get exist document_segment list and delete
             document_segments = (
                 db.session.query(DocumentSegment)
-                .filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
+                .filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
                 .all()
             )
 
@@ -188,7 +209,7 @@ class IndexingRunner:
                         "dataset_id": document_segment.dataset_id,
                     },
                 )
-                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                     child_chunks = document_segment.get_child_chunks()
                     if child_chunks:
                         child_documents = []
@@ -206,24 +227,20 @@ class IndexingRunner:
                         document.children = child_documents
                 documents.append(document)
             # build index
-            index_type = dataset_document.doc_form
+            index_type = requeried_document.doc_form
             index_processor = IndexProcessorFactory(index_type).init_index_processor()
             self._load(
-                index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
+                index_processor=index_processor,
+                dataset=dataset,
+                dataset_document=requeried_document,
+                documents=documents,
             )
         except DocumentIsPausedError:
-            raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
+            raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
         except ProviderTokenNotInitError as e:
-            dataset_document.indexing_status = "error"
-            dataset_document.error = str(e.description)
-            dataset_document.stopped_at = naive_utc_now()
-            db.session.commit()
+            self._handle_indexing_error(document_id, e)
         except Exception as e:
-            logger.exception("consume document failed")
-            dataset_document.indexing_status = "error"
-            dataset_document.error = str(e)
-            dataset_document.stopped_at = naive_utc_now()
-            db.session.commit()
+            self._handle_indexing_error(document_id, e)
 
     def indexing_estimate(
         self,
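All three run* methods previously repeated the same four-line error block in their except clauses; the new _handle_indexing_error funnels every failure path through one place and re-fetches the row by id so the status update cannot hit a detached instance. A toy sketch of the consolidation (a dict stands in for the ORM row, and the getattr fallback mirrors the helper's handling of exceptions that carry a .description):

import logging
from datetime import datetime

logger = logging.getLogger("indexing")

def handle_indexing_error(doc: dict, error: Exception) -> None:
    logger.exception("consume document failed")
    doc["indexing_status"] = "error"
    # Prefer a human-readable .description when the exception type has one
    # (e.g. ProviderTokenNotInitError); otherwise fall back to str(error).
    doc["error"] = str(getattr(error, "description", str(error)))
    doc["stopped_at"] = datetime.utcnow()  # naive UTC, like naive_utc_now()

doc = {"indexing_status": "indexing"}
try:
    raise ValueError("boom")
except Exception as e:
    handle_indexing_error(doc, e)
print(doc["indexing_status"], doc["error"])  # error boom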
@@ -100,7 +100,7 @@ class LLMGenerator:
         return name
 
     @classmethod
-    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
+    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str) -> Sequence[str]:
         output_parser = SuggestedQuestionsAfterAnswerOutputParser()
         format_instructions = output_parser.get_format_instructions()
 
@@ -119,6 +119,8 @@
 
         prompt_messages = [UserPromptMessage(content=prompt)]
 
+        questions: Sequence[str] = []
+
         try:
             response: LLMResult = model_instance.invoke_llm(
                 prompt_messages=list(prompt_messages),
@@ -1,17 +1,26 @@
 import json
+import logging
 import re
+from collections.abc import Sequence
 
 from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
 
+logger = logging.getLogger(__name__)
+
 
 class SuggestedQuestionsAfterAnswerOutputParser:
     def get_format_instructions(self) -> str:
         return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
 
-    def parse(self, text: str):
+    def parse(self, text: str) -> Sequence[str]:
         action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL)
+        questions: list[str] = []
         if action_match is not None:
-            json_obj = json.loads(action_match.group(0).strip())
-        else:
-            json_obj = []
-        return json_obj
+            try:
+                json_obj = json.loads(action_match.group(0).strip())
+            except json.JSONDecodeError as exc:
+                logger.warning("Failed to decode suggested questions payload: %s", exc)
+            else:
+                if isinstance(json_obj, list):
+                    questions = [question for question in json_obj if isinstance(question, str)]
+        return questions
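As a usage sketch, the hardened parse now degrades gracefully on the malformed model outputs that used to raise or leak non-list payloads to callers (sample inputs assumed):

parser = SuggestedQuestionsAfterAnswerOutputParser()
parser.parse('Sure! ["Q1?", "Q2?"] Hope that helps.')  # -> ['Q1?', 'Q2?']
parser.parse('["keep", 42, null]')  # -> ['keep']   (non-strings filtered out)
parser.parse('["unterminated]')     # -> []         (JSONDecodeError logged, not raised)
parser.parse('no brackets here')    # -> []         (no match, empty default)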
@@ -441,10 +441,14 @@ class LLMNode(Node):
         usage = LLMUsage.empty_usage()
         finish_reason = None
         full_text_buffer = io.StringIO()
+        collected_structured_output = None  # Collect structured_output from streaming chunks
         # Consume the invoke result and handle generator exception
         try:
             for result in invoke_result:
+                if isinstance(result, LLMResultChunkWithStructuredOutput):
+                    # Collect structured_output from the chunk
+                    if result.structured_output is not None:
+                        collected_structured_output = dict(result.structured_output)
                 yield result
                 if isinstance(result, LLMResultChunk):
                     contents = result.delta.message.content
@@ -492,6 +496,8 @@
             finish_reason=finish_reason,
             # Reasoning content for workflow variables and downstream nodes
             reasoning_content=reasoning_content,
+            # Pass structured output if collected from streaming chunks
+            structured_output=collected_structured_output,
         )
 
     @staticmethod
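The pattern in these two hunks — remember a side-channel payload while still yielding every chunk to the consumer — is easy to miss in diff form. A generic sketch with plain dicts standing in for the chunk classes:

from collections.abc import Iterator

def relay(chunks: Iterator[dict]) -> Iterator[dict]:
    collected = None
    for chunk in chunks:
        if chunk.get("structured_output") is not None:
            # dict(...) snapshots the payload so later mutation of the
            # chunk cannot change what we attach to the final result.
            collected = dict(chunk["structured_output"])
        yield chunk  # streaming consumers still see every chunk unchanged
    print("collected:", collected)  # complete once the stream is drained

list(relay(iter([{"t": "a"}, {"structured_output": {"answer": 42}}, {"t": "b"}])))
# prints: collected: {'answer': 42}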
@@ -747,7 +747,7 @@ class ParameterExtractorNode(Node):
         if model_mode == ModelMode.CHAT:
             system_prompt_messages = ChatModelMessage(
                 role=PromptMessageRole.SYSTEM,
-                text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str).replace("{{instructions}}", instruction),
+                text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction),
             )
             user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
             return [system_prompt_messages, user_prompt_message]
@@ -135,7 +135,7 @@ Here are the chat histories between human and assistant, inside <histories></his
 ### Instructions:
 Some extra information are provided below, you should always follow the instructions as possible as you can.
 <instructions>
-{{instructions}}
+{instructions}
 </instructions>
 """
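Together these two hunks close a real substitution gap: under str.format, {{...}} is an escaped literal brace, so the old format-then-replace pipeline in ParameterExtractorNode never filled the instructions slot once .format() had already collapsed {{instructions}} to {instructions}. A minimal repro with a stand-in template:

template = "<histories>{histories}</histories><instructions>{{instructions}}</instructions>"

# Old pipeline: .format() collapses "{{instructions}}" into "{instructions}",
# so the follow-up .replace() matches nothing and the slot stays unfilled.
old = template.format(histories="hi").replace("{{instructions}}", "be terse")
print(old)  # <histories>hi</histories><instructions>{instructions}</instructions>

# New pipeline: a real placeholder, filled in the same .format() call.
new = template.replace("{{instructions}}", "{instructions}").format(
    histories="hi", instructions="be terse"
)
print(new)  # <histories>hi</histories><instructions>be terse</instructions>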
@@ -6,8 +6,8 @@ from tasks.clean_dataset_task import clean_dataset_task
 @dataset_was_deleted.connect
 def handle(sender: Dataset, **kwargs):
     dataset = sender
-    assert dataset.doc_form
-    assert dataset.indexing_technique
+    if not dataset.doc_form or not dataset.indexing_technique:
+        return
     clean_dataset_task.delay(
         dataset.id,
         dataset.tenant_id,
@@ -8,6 +8,6 @@ def handle(sender, **kwargs):
     dataset_id = kwargs.get("dataset_id")
     doc_form = kwargs.get("doc_form")
     file_id = kwargs.get("file_id")
-    assert dataset_id is not None
-    assert doc_form is not None
+    if not dataset_id or not doc_form:
+        return
     clean_document_task.delay(document_id, dataset_id, doc_form, file_id)
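Both signal-handler hunks replace assert with an early return. Besides turning the crash behind issue #27073 into a graceful skip, assert is unsuitable for control flow because the interpreter strips it entirely under python -O, as a quick demonstration shows:

import subprocess
import sys

snippet = "assert False, 'guard'\nprint('reached')"
subprocess.run([sys.executable, "-c", snippet])        # AssertionError, no print
subprocess.run([sys.executable, "-O", "-c", snippet])  # prints: reached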
@@ -1,7 +1,12 @@
 from configs import dify_config
-from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN
+from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN, HEADER_NAME_PASSPORT
 from dify_app import DifyApp
 
+BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEADER_NAME_PASSPORT)
+SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization")
+AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN)
+FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
+
 
 def init_app(app: DifyApp):
     # register blueprint routers
@@ -17,7 +22,7 @@ def init_app(app: DifyApp):
 
     CORS(
         service_api_bp,
-        allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE],
+        allow_headers=list(SERVICE_API_HEADERS),
         methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
     )
     app.register_blueprint(service_api_bp)
@@ -26,7 +31,7 @@
         web_bp,
         resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
         supports_credentials=True,
-        allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN],
+        allow_headers=list(AUTHENTICATED_HEADERS),
         methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
         expose_headers=["X-Version", "X-Env"],
     )
@@ -36,7 +41,7 @@
         console_app_bp,
         resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
         supports_credentials=True,
-        allow_headers=["Content-Type", "Authorization", HEADER_NAME_CSRF_TOKEN],
+        allow_headers=list(AUTHENTICATED_HEADERS),
         methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
         expose_headers=["X-Version", "X-Env"],
     )
@@ -44,7 +49,7 @@
 
     CORS(
         files_bp,
-        allow_headers=["Content-Type", HEADER_NAME_CSRF_TOKEN],
+        allow_headers=list(FILES_HEADERS),
         methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
     )
     app.register_blueprint(files_bp)
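The new header constants build on one another by tuple unpacking, so each blueprint's allow list is a strict superset of the base set and changes in one place only. A small check with stubbed header values (the real strings live in constants.py):

HEADER_NAME_APP_CODE = "X-App-Code"        # stub values for illustration
HEADER_NAME_PASSPORT = "X-App-Passport"
HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"

BASE_CORS_HEADERS = ("Content-Type", HEADER_NAME_APP_CODE, HEADER_NAME_PASSPORT)
SERVICE_API_HEADERS = (*BASE_CORS_HEADERS, "Authorization")
AUTHENTICATED_HEADERS = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN)
FILES_HEADERS = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)

assert set(BASE_CORS_HEADERS) < set(SERVICE_API_HEADERS) < set(AUTHENTICATED_HEADERS)
print(len(AUTHENTICATED_HEADERS))  # 5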
@@ -22,55 +22,6 @@ def upgrade():
         batch_op.add_column(sa.Column('app_mode', sa.String(length=255), nullable=True))
         batch_op.create_index('message_app_mode_idx', ['app_mode'], unique=False)
 
-    conn = op.get_bind()
-
-    # Strategy: Update in batches to minimize lock time
-    # For large tables (millions of rows), this prevents long-running transactions
-    batch_size = 10000
-
-    print("Starting backfill of app_mode from conversations...")
-
-    # Use a more efficient UPDATE with JOIN
-    # This query updates messages.app_mode from conversations.mode
-    # Using string formatting for LIMIT since it's a constant
-    update_query = f"""
-        UPDATE messages m
-        SET app_mode = c.mode
-        FROM conversations c
-        WHERE m.conversation_id = c.id
-        AND m.app_mode IS NULL
-        AND m.id IN (
-            SELECT id FROM messages
-            WHERE app_mode IS NULL
-            LIMIT {batch_size}
-        )
-    """
-
-    # Execute batched updates
-    total_updated = 0
-    iteration = 0
-    while True:
-        iteration += 1
-        result = conn.execute(sa.text(update_query))
-
-        # Check if result is None or has no rowcount
-        if result is None:
-            print("Warning: Query returned None, stopping backfill")
-            break
-
-        rows_updated = result.rowcount if hasattr(result, 'rowcount') else 0
-        total_updated += rows_updated
-
-        if rows_updated == 0:
-            break
-
-        print(f"Iteration {iteration}: Updated {rows_updated} messages (total: {total_updated})")
-
-        # For very large tables, add a small delay to reduce load
-        # Uncomment if needed: import time; time.sleep(0.1)
-
-    print(f"Backfill completed. Total messages updated: {total_updated}")
-
     # ### end Alembic commands ###
@@ -64,7 +64,7 @@ dependencies = [
     "pycryptodome==3.19.1",
     "pydantic~=2.11.4",
     "pydantic-extra-types~=2.10.3",
-    "pydantic-settings~=2.9.1",
+    "pydantic-settings~=2.11.0",
     "pyjwt~=2.10.1",
     "pypdfium2==4.30.0",
     "python-docx~=1.1.0",
@@ -288,9 +288,10 @@ class MessageService:
         )
 
         with measure_time() as timer:
-            questions: list[str] = LLMGenerator.generate_suggested_questions_after_answer(
+            questions_sequence = LLMGenerator.generate_suggested_questions_after_answer(
                 tenant_id=app_model.tenant_id, histories=histories
             )
+            questions: list[str] = list(questions_sequence)
 
         # get tracing instance
         trace_manager = TraceQueueManager(app_id=app_model.id)
@@ -79,7 +79,7 @@ class VariableTruncator:
         self,
         string_length_limit=5000,
         array_element_limit: int = 20,
-        max_size_bytes: int = 1024_000,  # 100KB
+        max_size_bytes: int = 1024_000,  # 1000 KiB
     ):
         if string_length_limit <= 3:
             raise ValueError("string_length_limit should be greater than 3.")
@@ -0,0 +1,216 @@
+from unittest.mock import Mock, patch
+
+import pytest
+
+from models.account import Account, TenantAccountRole
+from models.dataset import Dataset
+from services.dataset_service import DatasetService
+
+
+class DatasetDeleteTestDataFactory:
+    """Factory class for creating test data and mock objects for dataset delete tests."""
+
+    @staticmethod
+    def create_dataset_mock(
+        dataset_id: str = "dataset-123",
+        tenant_id: str = "test-tenant-123",
+        created_by: str = "creator-456",
+        doc_form: str | None = None,
+        indexing_technique: str | None = "high_quality",
+        **kwargs,
+    ) -> Mock:
+        """Create a mock dataset with specified attributes."""
+        dataset = Mock(spec=Dataset)
+        dataset.id = dataset_id
+        dataset.tenant_id = tenant_id
+        dataset.created_by = created_by
+        dataset.doc_form = doc_form
+        dataset.indexing_technique = indexing_technique
+        for key, value in kwargs.items():
+            setattr(dataset, key, value)
+        return dataset
+
+    @staticmethod
+    def create_user_mock(
+        user_id: str = "user-789",
+        tenant_id: str = "test-tenant-123",
+        role: TenantAccountRole = TenantAccountRole.ADMIN,
+        **kwargs,
+    ) -> Mock:
+        """Create a mock user with specified attributes."""
+        user = Mock(spec=Account)
+        user.id = user_id
+        user.current_tenant_id = tenant_id
+        user.current_role = role
+        for key, value in kwargs.items():
+            setattr(user, key, value)
+        return user
+
+
+class TestDatasetServiceDeleteDataset:
+    """
+    Comprehensive unit tests for DatasetService.delete_dataset method.
+
+    This test suite covers all deletion scenarios including:
+    - Normal dataset deletion with documents
+    - Empty dataset deletion (no documents, doc_form is None)
+    - Dataset deletion with missing indexing_technique
+    - Permission checks
+    - Event handling
+
+    This test suite provides regression protection for issue #27073.
+    """
+
+    @pytest.fixture
+    def mock_dataset_service_dependencies(self):
+        """Common mock setup for dataset service dependencies."""
+        with (
+            patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset,
+            patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
+            patch("extensions.ext_database.db.session") as mock_db,
+            patch("services.dataset_service.dataset_was_deleted") as mock_dataset_was_deleted,
+        ):
+            yield {
+                "get_dataset": mock_get_dataset,
+                "check_permission": mock_check_perm,
+                "db_session": mock_db,
+                "dataset_was_deleted": mock_dataset_was_deleted,
+            }
+
+    def test_delete_dataset_with_documents_success(self, mock_dataset_service_dependencies):
+        """
+        Test successful deletion of a dataset with documents.
+
+        This test verifies:
+        - Dataset is retrieved correctly
+        - Permission check is performed
+        - dataset_was_deleted event is sent
+        - Dataset is deleted from database
+        - Method returns True
+        """
+        # Arrange
+        dataset = DatasetDeleteTestDataFactory.create_dataset_mock(
+            doc_form="text_model", indexing_technique="high_quality"
+        )
+        user = DatasetDeleteTestDataFactory.create_user_mock()
+
+        mock_dataset_service_dependencies["get_dataset"].return_value = dataset
+
+        # Act
+        result = DatasetService.delete_dataset(dataset.id, user)
+
+        # Assert
+        assert result is True
+        mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id)
+        mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
+        mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset)
+        mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset)
+        mock_dataset_service_dependencies["db_session"].commit.assert_called_once()
+
+    def test_delete_empty_dataset_success(self, mock_dataset_service_dependencies):
+        """
+        Test successful deletion of an empty dataset (no documents, doc_form is None).
+
+        This test verifies that:
+        - Empty datasets can be deleted without errors
+        - dataset_was_deleted event is sent (event handler will skip cleanup if doc_form is None)
+        - Dataset is deleted from database
+        - Method returns True
+
+        This is the primary test for issue #27073 where deleting an empty dataset
+        caused internal server error due to assertion failure in event handlers.
+        """
+        # Arrange
+        dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form=None, indexing_technique=None)
+        user = DatasetDeleteTestDataFactory.create_user_mock()
+
+        mock_dataset_service_dependencies["get_dataset"].return_value = dataset
+
+        # Act
+        result = DatasetService.delete_dataset(dataset.id, user)
+
+        # Assert - Verify complete deletion flow
+        assert result is True
+        mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id)
+        mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
+        mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset)
+        mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset)
+        mock_dataset_service_dependencies["db_session"].commit.assert_called_once()
+
+    def test_delete_dataset_with_partial_none_values(self, mock_dataset_service_dependencies):
+        """
+        Test deletion of dataset with partial None values.
+
+        This test verifies that datasets with partial None values (e.g., doc_form exists
+        but indexing_technique is None) can be deleted successfully. The event handler
+        will skip cleanup if any required field is None.
+
+        Improvement based on Gemini Code Assist suggestion: Added comprehensive assertions
+        to verify all core deletion operations are performed, not just event sending.
+        """
+        # Arrange
+        dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form="text_model", indexing_technique=None)
+        user = DatasetDeleteTestDataFactory.create_user_mock()
+
+        mock_dataset_service_dependencies["get_dataset"].return_value = dataset
+
+        # Act
+        result = DatasetService.delete_dataset(dataset.id, user)
+
+        # Assert - Verify complete deletion flow (Gemini suggestion implemented)
+        assert result is True
+        mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id)
+        mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
+        mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset)
+        mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset)
+        mock_dataset_service_dependencies["db_session"].commit.assert_called_once()
+
+    def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, mock_dataset_service_dependencies):
+        """
+        Test deletion of dataset where doc_form is None but indexing_technique exists.
+
+        This edge case can occur in certain dataset configurations and should be handled
+        gracefully by the event handler's conditional check.
+        """
+        # Arrange
+        dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form=None, indexing_technique="high_quality")
+        user = DatasetDeleteTestDataFactory.create_user_mock()
+
+        mock_dataset_service_dependencies["get_dataset"].return_value = dataset
+
+        # Act
+        result = DatasetService.delete_dataset(dataset.id, user)
+
+        # Assert - Verify complete deletion flow
+        assert result is True
+        mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id)
+        mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
+        mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset)
+        mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset)
+        mock_dataset_service_dependencies["db_session"].commit.assert_called_once()
+
+    def test_delete_dataset_not_found(self, mock_dataset_service_dependencies):
+        """
+        Test deletion attempt when dataset doesn't exist.
+
+        This test verifies that:
+        - Method returns False when dataset is not found
+        - No deletion operations are performed
+        - No events are sent
+        """
+        # Arrange
+        dataset_id = "non-existent-dataset"
+        user = DatasetDeleteTestDataFactory.create_user_mock()
+
+        mock_dataset_service_dependencies["get_dataset"].return_value = None
+
+        # Act
+        result = DatasetService.delete_dataset(dataset_id, user)
+
+        # Assert
+        assert result is False
+        mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset_id)
+        mock_dataset_service_dependencies["check_permission"].assert_not_called()
+        mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_not_called()
+        mock_dataset_service_dependencies["db_session"].delete.assert_not_called()
+        mock_dataset_service_dependencies["db_session"].commit.assert_not_called()
api/uv.lock (generated), 8 changes
@@ -1557,7 +1557,7 @@ requires-dist = [
     { name = "pycryptodome", specifier = "==3.19.1" },
     { name = "pydantic", specifier = "~=2.11.4" },
     { name = "pydantic-extra-types", specifier = "~=2.10.3" },
-    { name = "pydantic-settings", specifier = "~=2.9.1" },
+    { name = "pydantic-settings", specifier = "~=2.11.0" },
     { name = "pyjwt", specifier = "~=2.10.1" },
     { name = "pypdfium2", specifier = "==4.30.0" },
     { name = "python-docx", specifier = "~=1.1.0" },
@@ -4779,16 +4779,16 @@ wheels = [
 
 [[package]]
 name = "pydantic-settings"
-version = "2.9.1"
+version = "2.11.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "pydantic" },
     { name = "python-dotenv" },
     { name = "typing-inspection" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/67/1d/42628a2c33e93f8e9acbde0d5d735fa0850f3e6a2f8cb1eb6c40b9a732ac/pydantic_settings-2.9.1.tar.gz", hash = "sha256:c509bf79d27563add44e8446233359004ed85066cd096d8b510f715e6ef5d268", size = 163234, upload-time = "2025-04-18T16:44:48.265Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/20/c5/dbbc27b814c71676593d1c3f718e6cd7d4f00652cefa24b75f7aa3efb25e/pydantic_settings-2.11.0.tar.gz", hash = "sha256:d0e87a1c7d33593beb7194adb8470fc426e95ba02af83a0f23474a04c9a08180", size = 188394, upload-time = "2025-09-24T14:19:11.764Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/b6/5f/d6d641b490fd3ec2c4c13b4244d68deea3a1b970a97be64f34fb5504ff72/pydantic_settings-2.9.1-py3-none-any.whl", hash = "sha256:59b4f431b1defb26fe620c71a7d3968a710d719f5f4cdbbdb7926edeb770f6ef", size = 44356, upload-time = "2025-04-18T16:44:46.617Z" },
+    { url = "https://files.pythonhosted.org/packages/83/d6/887a1ff844e64aa823fb4905978d882a633cfe295c32eacad582b78a7d8b/pydantic_settings-2.11.0-py3-none-any.whl", hash = "sha256:fe2cea3413b9530d10f3a5875adffb17ada5c1e1bab0b2885546d7310415207c", size = 48608, upload-time = "2025-09-24T14:19:10.015Z" },
 ]
 
 [[package]]