mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 02:18:08 +08:00
fix: fix miss use db.session (#31971)
This commit is contained in:
@ -8,7 +8,6 @@ from sqlalchemy import delete, select
|
|||||||
from core.db.session_factory import session_factory
|
from core.db.session_factory import session_factory
|
||||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||||
from extensions.ext_database import db
|
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from models.dataset import Dataset, Document, DocumentSegment
|
from models.dataset import Dataset, Document, DocumentSegment
|
||||||
|
|
||||||
@ -27,7 +26,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
|
|||||||
logger.info(click.style(f"Start update document: {document_id}", fg="green"))
|
logger.info(click.style(f"Start update document: {document_id}", fg="green"))
|
||||||
start_at = time.perf_counter()
|
start_at = time.perf_counter()
|
||||||
|
|
||||||
with session_factory.create_session() as session:
|
with session_factory.create_session() as session, session.begin():
|
||||||
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||||
|
|
||||||
if not document:
|
if not document:
|
||||||
@ -36,7 +35,6 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
|
|||||||
|
|
||||||
document.indexing_status = "parsing"
|
document.indexing_status = "parsing"
|
||||||
document.processing_started_at = naive_utc_now()
|
document.processing_started_at = naive_utc_now()
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# delete all document segment and index
|
# delete all document segment and index
|
||||||
try:
|
try:
|
||||||
@ -56,7 +54,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
|
|||||||
segment_ids = [segment.id for segment in segments]
|
segment_ids = [segment.id for segment in segments]
|
||||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
|
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
|
||||||
session.execute(segment_delete_stmt)
|
session.execute(segment_delete_stmt)
|
||||||
db.session.commit()
|
|
||||||
end_at = time.perf_counter()
|
end_at = time.perf_counter()
|
||||||
logger.info(
|
logger.info(
|
||||||
click.style(
|
click.style(
|
||||||
|
|||||||
@ -0,0 +1,182 @@
|
|||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from faker import Faker
|
||||||
|
|
||||||
|
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||||
|
from models.dataset import Dataset, Document, DocumentSegment
|
||||||
|
from tasks.document_indexing_update_task import document_indexing_update_task
|
||||||
|
|
||||||
|
|
||||||
|
class TestDocumentIndexingUpdateTask:
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_external_dependencies(self):
|
||||||
|
"""Patch external collaborators used by the update task.
|
||||||
|
- IndexProcessorFactory.init_index_processor().clean(...)
|
||||||
|
- IndexingRunner.run([...])
|
||||||
|
"""
|
||||||
|
with (
|
||||||
|
patch("tasks.document_indexing_update_task.IndexProcessorFactory") as mock_factory,
|
||||||
|
patch("tasks.document_indexing_update_task.IndexingRunner") as mock_runner,
|
||||||
|
):
|
||||||
|
processor_instance = MagicMock()
|
||||||
|
mock_factory.return_value.init_index_processor.return_value = processor_instance
|
||||||
|
|
||||||
|
runner_instance = MagicMock()
|
||||||
|
mock_runner.return_value = runner_instance
|
||||||
|
|
||||||
|
yield {
|
||||||
|
"factory": mock_factory,
|
||||||
|
"processor": processor_instance,
|
||||||
|
"runner": mock_runner,
|
||||||
|
"runner_instance": runner_instance,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _create_dataset_document_with_segments(self, db_session_with_containers, *, segment_count: int = 2):
|
||||||
|
fake = Faker()
|
||||||
|
|
||||||
|
# Account and tenant
|
||||||
|
account = Account(
|
||||||
|
email=fake.email(),
|
||||||
|
name=fake.name(),
|
||||||
|
interface_language="en-US",
|
||||||
|
status="active",
|
||||||
|
)
|
||||||
|
db_session_with_containers.add(account)
|
||||||
|
db_session_with_containers.commit()
|
||||||
|
|
||||||
|
tenant = Tenant(name=fake.company(), status="normal")
|
||||||
|
db_session_with_containers.add(tenant)
|
||||||
|
db_session_with_containers.commit()
|
||||||
|
|
||||||
|
join = TenantAccountJoin(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
account_id=account.id,
|
||||||
|
role=TenantAccountRole.OWNER,
|
||||||
|
current=True,
|
||||||
|
)
|
||||||
|
db_session_with_containers.add(join)
|
||||||
|
db_session_with_containers.commit()
|
||||||
|
|
||||||
|
# Dataset and document
|
||||||
|
dataset = Dataset(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
name=fake.company(),
|
||||||
|
description=fake.text(max_nb_chars=64),
|
||||||
|
data_source_type="upload_file",
|
||||||
|
indexing_technique="high_quality",
|
||||||
|
created_by=account.id,
|
||||||
|
)
|
||||||
|
db_session_with_containers.add(dataset)
|
||||||
|
db_session_with_containers.commit()
|
||||||
|
|
||||||
|
document = Document(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
position=0,
|
||||||
|
data_source_type="upload_file",
|
||||||
|
batch="test_batch",
|
||||||
|
name=fake.file_name(),
|
||||||
|
created_from="upload_file",
|
||||||
|
created_by=account.id,
|
||||||
|
indexing_status="waiting",
|
||||||
|
enabled=True,
|
||||||
|
doc_form="text_model",
|
||||||
|
)
|
||||||
|
db_session_with_containers.add(document)
|
||||||
|
db_session_with_containers.commit()
|
||||||
|
|
||||||
|
# Segments
|
||||||
|
node_ids = []
|
||||||
|
for i in range(segment_count):
|
||||||
|
node_id = f"node-{i + 1}"
|
||||||
|
seg = DocumentSegment(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
document_id=document.id,
|
||||||
|
position=i,
|
||||||
|
content=fake.text(max_nb_chars=32),
|
||||||
|
answer=None,
|
||||||
|
word_count=10,
|
||||||
|
tokens=5,
|
||||||
|
index_node_id=node_id,
|
||||||
|
status="completed",
|
||||||
|
created_by=account.id,
|
||||||
|
)
|
||||||
|
db_session_with_containers.add(seg)
|
||||||
|
node_ids.append(node_id)
|
||||||
|
db_session_with_containers.commit()
|
||||||
|
|
||||||
|
# Refresh to ensure ORM state
|
||||||
|
db_session_with_containers.refresh(dataset)
|
||||||
|
db_session_with_containers.refresh(document)
|
||||||
|
|
||||||
|
return dataset, document, node_ids
|
||||||
|
|
||||||
|
def test_cleans_segments_and_reindexes(self, db_session_with_containers, mock_external_dependencies):
|
||||||
|
dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
document_indexing_update_task(dataset.id, document.id)
|
||||||
|
|
||||||
|
# Ensure we see committed changes from another session
|
||||||
|
db_session_with_containers.expire_all()
|
||||||
|
|
||||||
|
# Assert document status updated before reindex
|
||||||
|
updated = db_session_with_containers.query(Document).where(Document.id == document.id).first()
|
||||||
|
assert updated.indexing_status == "parsing"
|
||||||
|
assert updated.processing_started_at is not None
|
||||||
|
|
||||||
|
# Segments should be deleted
|
||||||
|
remaining = (
|
||||||
|
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
|
||||||
|
)
|
||||||
|
assert remaining == 0
|
||||||
|
|
||||||
|
# Assert index processor clean was called with expected args
|
||||||
|
clean_call = mock_external_dependencies["processor"].clean.call_args
|
||||||
|
assert clean_call is not None
|
||||||
|
args, kwargs = clean_call
|
||||||
|
# args[0] is a Dataset instance (from another session) — validate by id
|
||||||
|
assert getattr(args[0], "id", None) == dataset.id
|
||||||
|
# args[1] should contain our node_ids
|
||||||
|
assert set(args[1]) == set(node_ids)
|
||||||
|
assert kwargs.get("with_keywords") is True
|
||||||
|
assert kwargs.get("delete_child_chunks") is True
|
||||||
|
|
||||||
|
# Assert indexing runner invoked with the updated document
|
||||||
|
run_call = mock_external_dependencies["runner_instance"].run.call_args
|
||||||
|
assert run_call is not None
|
||||||
|
run_docs = run_call[0][0]
|
||||||
|
assert len(run_docs) == 1
|
||||||
|
first = run_docs[0]
|
||||||
|
assert getattr(first, "id", None) == document.id
|
||||||
|
|
||||||
|
def test_clean_error_is_logged_and_indexing_continues(self, db_session_with_containers, mock_external_dependencies):
|
||||||
|
dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers)
|
||||||
|
|
||||||
|
# Force clean to raise; task should continue to indexing
|
||||||
|
mock_external_dependencies["processor"].clean.side_effect = Exception("boom")
|
||||||
|
|
||||||
|
document_indexing_update_task(dataset.id, document.id)
|
||||||
|
|
||||||
|
# Ensure we see committed changes from another session
|
||||||
|
db_session_with_containers.expire_all()
|
||||||
|
|
||||||
|
# Indexing should still be triggered
|
||||||
|
mock_external_dependencies["runner_instance"].run.assert_called_once()
|
||||||
|
|
||||||
|
# Segments should remain (since clean failed before DB delete)
|
||||||
|
remaining = (
|
||||||
|
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
|
||||||
|
)
|
||||||
|
assert remaining > 0
|
||||||
|
|
||||||
|
def test_document_not_found_noop(self, db_session_with_containers, mock_external_dependencies):
|
||||||
|
fake = Faker()
|
||||||
|
# Act with non-existent document id
|
||||||
|
document_indexing_update_task(dataset_id=fake.uuid4(), document_id=fake.uuid4())
|
||||||
|
|
||||||
|
# Neither processor nor runner should be called
|
||||||
|
mock_external_dependencies["processor"].clean.assert_not_called()
|
||||||
|
mock_external_dependencies["runner_instance"].run.assert_not_called()
|
||||||
Reference in New Issue
Block a user