Merge remote-tracking branch 'origin/main' into feat/support-agent-sandbox

This commit is contained in:
yyh
2026-03-25 11:50:33 +08:00
82 changed files with 1074 additions and 1135 deletions

View File

@ -1476,8 +1476,8 @@ class TestDatasetIndexingStatusApi:
return_value=MagicMock(all=lambda: [document]),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)),
"controllers.console.datasets.datasets.db.session.scalar",
return_value=3,
),
):
response, status = method(api, "dataset-1")
@ -1526,13 +1526,6 @@ class TestDatasetIndexingStatusApi:
document.error = None
document.stopped_at = None
# First count = completed segments, second = total segments
query_mock = MagicMock()
query_mock.where.side_effect = [
MagicMock(count=lambda: 2),
MagicMock(count=lambda: 5),
]
with (
app.test_request_context("/"),
patch(
@ -1544,8 +1537,8 @@ class TestDatasetIndexingStatusApi:
return_value=MagicMock(all=lambda: [document]),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=query_mock,
"controllers.console.datasets.datasets.db.session.scalar",
side_effect=[2, 5],
),
):
response, status = method(api, "dataset-1")
@ -1591,8 +1584,8 @@ class TestDatasetApiKeyApi:
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)),
"controllers.console.datasets.datasets.db.session.scalar",
return_value=3,
),
patch(
"controllers.console.datasets.datasets.ApiToken.generate_api_key",
@ -1625,8 +1618,8 @@ class TestDatasetApiKeyApi:
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 10)),
"controllers.console.datasets.datasets.db.session.scalar",
return_value=10,
),
):
with pytest.raises(BadRequest) as exc_info:
@ -1653,8 +1646,8 @@ class TestDatasetApiDeleteApi:
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: mock_key)),
"controllers.console.datasets.datasets.db.session.scalar",
return_value=mock_key,
),
patch(
"controllers.console.datasets.datasets.db.session.commit",
@ -1681,8 +1674,8 @@ class TestDatasetApiDeleteApi:
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: None)),
"controllers.console.datasets.datasets.db.session.scalar",
return_value=None,
),
):
with pytest.raises(NotFound):

View File

@ -526,8 +526,8 @@ class TestDatasetDocumentSegmentUpdateApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=segment,
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
@ -621,8 +621,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=upload_file,
),
patch(
"controllers.console.datasets.datasets_segments.redis_client.setnx",
@ -706,8 +706,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: None)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=None,
),
):
with pytest.raises(NotFound):
@ -738,8 +738,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=upload_file,
),
):
with pytest.raises(ValueError):
@ -770,8 +770,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=upload_file,
),
patch(
"controllers.console.datasets.datasets_segments.redis_client.setnx",
@ -831,8 +831,8 @@ class TestChildChunkAddApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=segment,
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
@ -880,8 +880,8 @@ class TestChildChunkAddApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=segment,
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
@ -924,11 +924,8 @@ class TestChildChunkUpdateApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
side_effect=[
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)),
],
"controllers.console.datasets.datasets_segments.db.session.scalar",
side_effect=[segment, child_chunk],
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
@ -970,11 +967,8 @@ class TestChildChunkUpdateApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
side_effect=[
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)),
],
"controllers.console.datasets.datasets_segments.db.session.scalar",
side_effect=[segment, child_chunk],
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
@ -1180,8 +1174,8 @@ class TestSegmentOperationCases:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=upload_file,
),
):
with pytest.raises(NotFound):
@ -1215,8 +1209,8 @@ class TestSegmentOperationCases:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=upload_file,
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",

View File

@ -4,6 +4,7 @@ from unittest.mock import Mock, patch
import pytest
from core.entities.knowledge_entities import PreviewDetail
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
from core.rag.models.document import AttachmentDocument, Document
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
@ -21,7 +22,7 @@ class TestParagraphIndexProcessor:
dataset = Mock()
dataset.id = "dataset-1"
dataset.tenant_id = "tenant-1"
dataset.indexing_technique = "high_quality"
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
dataset.is_multimodal = True
return dataset
@ -167,7 +168,7 @@ class TestParagraphIndexProcessor:
def test_load_uses_keyword_add_texts_with_keywords_when_economy(
self, processor: ParagraphIndexProcessor, dataset: Mock
) -> None:
dataset.indexing_technique = "economy"
dataset.indexing_technique = IndexTechniqueType.ECONOMY
docs = [Document(page_content="chunk", metadata={})]
with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls:
@ -178,7 +179,7 @@ class TestParagraphIndexProcessor:
def test_load_uses_keyword_add_texts_without_keywords_when_economy(
self, processor: ParagraphIndexProcessor, dataset: Mock
) -> None:
dataset.indexing_technique = "economy"
dataset.indexing_technique = IndexTechniqueType.ECONOMY
docs = [Document(page_content="chunk", metadata={})]
with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls:
@ -208,7 +209,7 @@ class TestParagraphIndexProcessor:
def test_clean_economy_deletes_summaries_and_keywords(
self, processor: ParagraphIndexProcessor, dataset: Mock
) -> None:
dataset.indexing_technique = "economy"
dataset.indexing_technique = IndexTechniqueType.ECONOMY
with (
patch(
@ -222,7 +223,7 @@ class TestParagraphIndexProcessor:
mock_keyword_cls.return_value.delete.assert_called_once()
def test_clean_deletes_keywords_by_ids(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None:
dataset.indexing_technique = "economy"
dataset.indexing_technique = IndexTechniqueType.ECONOMY
with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls:
processor.clean(dataset, ["node-2"], with_keywords=True)
@ -267,7 +268,7 @@ class TestParagraphIndexProcessor:
def test_index_list_chunks_economy(
self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock
) -> None:
dataset.indexing_technique = "economy"
dataset.indexing_technique = IndexTechniqueType.ECONOMY
with (
patch(
"core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash",

View File

@ -4,6 +4,7 @@ from unittest.mock import MagicMock, Mock, patch
import pytest
from core.entities.knowledge_entities import PreviewDetail
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from services.entities.knowledge_entities.knowledge_entities import ParentMode
@ -19,7 +20,7 @@ class TestParentChildIndexProcessor:
dataset = Mock()
dataset.id = "dataset-1"
dataset.tenant_id = "tenant-1"
dataset.indexing_technique = "high_quality"
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
dataset.is_multimodal = True
return dataset

View File

@ -6,6 +6,7 @@ import pytest
from werkzeug.datastructures import FileStorage
from core.entities.knowledge_entities import PreviewDetail
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor
from core.rag.models.document import AttachmentDocument, Document
@ -33,7 +34,7 @@ class TestQAIndexProcessor:
dataset = Mock()
dataset.id = "dataset-1"
dataset.tenant_id = "tenant-1"
dataset.indexing_technique = "high_quality"
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
dataset.is_multimodal = True
return dataset
@ -207,7 +208,7 @@ class TestQAIndexProcessor:
vector.create_multimodal.assert_called_once_with(multimodal_docs)
def test_load_skips_vector_for_non_high_quality(self, processor: QAIndexProcessor, dataset: Mock) -> None:
dataset.indexing_technique = "economy"
dataset.indexing_technique = IndexTechniqueType.ECONOMY
docs = [Document(page_content="Q1", metadata={"answer": "A1"})]
with patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls:
@ -298,7 +299,7 @@ class TestQAIndexProcessor:
def test_index_requires_high_quality(
self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock
) -> None:
dataset.indexing_technique = "economy"
dataset.indexing_technique = IndexTechniqueType.ECONOMY
qa_chunks = SimpleNamespace(qa_chunks=[SimpleNamespace(question="Q1", answer="A1")])
with (

View File

@ -61,7 +61,7 @@ from core.indexing_runner import (
DocumentIsPausedError,
IndexingRunner,
)
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.models.document import ChildDocument, Document
from dify_graph.model_runtime.entities.model_entities import ModelType
from libs.datetime_utils import naive_utc_now
@ -76,7 +76,7 @@ from models.dataset import Document as DatasetDocument
def create_mock_dataset(
dataset_id: str | None = None,
tenant_id: str | None = None,
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
embedding_provider: str = "openai",
embedding_model: str = "text-embedding-ada-002",
) -> Mock:
@ -458,7 +458,7 @@ class TestIndexingRunnerTransform:
dataset = Mock(spec=Dataset)
dataset.id = str(uuid.uuid4())
dataset.tenant_id = str(uuid.uuid4())
dataset.indexing_technique = "high_quality"
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
dataset.embedding_model_provider = "openai"
dataset.embedding_model = "text-embedding-ada-002"
return dataset
@ -521,7 +521,7 @@ class TestIndexingRunnerTransform:
"""Test transformation with economy indexing (no embeddings)."""
# Arrange
runner = IndexingRunner()
sample_dataset.indexing_technique = "economy"
sample_dataset.indexing_technique = IndexTechniqueType.ECONOMY
mock_processor = MagicMock()
transformed_docs = [
@ -605,7 +605,7 @@ class TestIndexingRunnerLoad:
dataset = Mock(spec=Dataset)
dataset.id = str(uuid.uuid4())
dataset.tenant_id = str(uuid.uuid4())
dataset.indexing_technique = "high_quality"
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
dataset.embedding_model_provider = "openai"
dataset.embedding_model = "text-embedding-ada-002"
return dataset
@ -674,7 +674,7 @@ class TestIndexingRunnerLoad:
"""Test loading with economy indexing (keyword only)."""
# Arrange
runner = IndexingRunner()
sample_dataset.indexing_technique = "economy"
sample_dataset.indexing_technique = IndexTechniqueType.ECONOMY
mock_processor = MagicMock()
@ -701,7 +701,7 @@ class TestIndexingRunnerLoad:
# Arrange
runner = IndexingRunner()
sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
sample_dataset.indexing_technique = "high_quality"
sample_dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
# Add child documents
for doc in sample_documents:
@ -795,7 +795,7 @@ class TestIndexingRunnerRun:
mock_dataset = Mock(spec=Dataset)
mock_dataset.id = doc.dataset_id
mock_dataset.tenant_id = doc.tenant_id
mock_dataset.indexing_technique = "economy"
mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY
mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset
mock_process_rule = Mock(spec=DatasetProcessRule)
@ -949,7 +949,7 @@ class TestIndexingRunnerRun:
mock_dependencies["db"].session.get.side_effect = get_side_effect
mock_dataset = Mock(spec=Dataset)
mock_dataset.indexing_technique = "economy"
mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY
mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset
mock_process_rule = Mock(spec=DatasetProcessRule)

View File

@ -5,6 +5,7 @@ from unittest.mock import Mock
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.workflow.nodes.knowledge_index.entities import KnowledgeIndexNodeData
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
from core.workflow.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode
@ -78,7 +79,7 @@ def sample_node_data():
type="knowledge-index",
chunk_structure="general_structure",
index_chunk_variable_selector=["start", "chunks"],
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
summary_index_setting=None,
)

View File

@ -1,7 +1,11 @@
import threading
import time
from dataclasses import dataclass
from typing import cast
import pytest
from libs.broadcast_channel.exc import SubscriptionClosedError
from libs.broadcast_channel.redis.streams_channel import (
StreamsBroadcastChannel,
StreamsTopic,
@ -22,6 +26,7 @@ class FakeStreamsRedis:
self._store: dict[str, list[tuple[str, dict]]] = {}
self._next_id: dict[str, int] = {}
self._expire_calls: dict[str, int] = {}
self._dollar_snapshots: dict[str, int] = {}
# Publisher API
def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str:
@ -47,7 +52,9 @@ class FakeStreamsRedis:
# Find position strictly greater than last_id
start_idx = 0
if last_id != "0-0":
if last_id == "$":
start_idx = self._dollar_snapshots.setdefault(key, len(entries))
elif last_id != "0-0":
for i, (eid, _f) in enumerate(entries):
if eid == last_id:
start_idx = i + 1
@ -63,6 +70,55 @@ class FakeStreamsRedis:
return [(key, batch)]
class FailExpireRedis(FakeStreamsRedis):
def expire(self, key: str, seconds: int) -> None:
raise RuntimeError("expire failed")
class BlockingRedis:
def __init__(self) -> None:
self._release = threading.Event()
def xread(self, streams: dict, block: int | None = None, count: int | None = None):
self._release.wait(timeout=block / 1000.0 if block else None)
return []
def release(self) -> None:
self._release.set()
@dataclass(frozen=True)
class ListenPayloadCase:
name: str
fields: object
expected_messages: list[bytes]
def build_listen_payload_cases() -> list[ListenPayloadCase]:
return [
ListenPayloadCase(
name="string_payload_is_encoded",
fields={b"data": "hello"},
expected_messages=[b"hello"],
),
ListenPayloadCase(
name="bytearray_payload_is_converted",
fields={b"data": bytearray(b"world")},
expected_messages=[b"world"],
),
ListenPayloadCase(
name="non_dict_fields_are_ignored",
fields=[("data", b"ignored")],
expected_messages=[],
),
ListenPayloadCase(
name="missing_payload_is_ignored",
fields={b"other": b"ignored"},
expected_messages=[],
),
]
@pytest.fixture
def fake_redis() -> FakeStreamsRedis:
return FakeStreamsRedis()
@ -94,21 +150,37 @@ class TestStreamsBroadcastChannel:
# Expire called after publish
assert fake_redis._expire_calls.get("stream:beta", 0) >= 1
def test_topic_exposes_self_as_producer_and_subscriber(self, streams_channel: StreamsBroadcastChannel):
topic = streams_channel.topic("producer-subscriber")
assert topic.as_producer() is topic
assert topic.as_subscriber() is topic
def test_publish_logs_warning_when_expire_fails(self, caplog: pytest.LogCaptureFixture):
channel = StreamsBroadcastChannel(FailExpireRedis(), retention_seconds=60)
topic = channel.topic("expire-warning")
topic.publish(b"payload")
assert "Failed to set expire for stream key" in caplog.text
class TestStreamsSubscription:
def test_subscribe_and_receive_from_beginning(self, streams_channel: StreamsBroadcastChannel):
def test_subscribe_only_receives_messages_published_after_subscription_starts(
self,
streams_channel: StreamsBroadcastChannel,
):
topic = streams_channel.topic("gamma")
# Pre-publish events before subscribing (late subscriber)
topic.publish(b"e1")
topic.publish(b"e2")
topic.publish(b"before-subscribe")
sub = topic.subscribe()
assert isinstance(sub, _StreamsSubscription)
received: list[bytes] = []
with sub:
# Give listener thread a moment to xread
time.sleep(0.05)
assert sub.receive(timeout=0.05) is None
topic.publish(b"after-subscribe-1")
topic.publish(b"after-subscribe-2")
# Drain using receive() to avoid indefinite iteration in tests
for _ in range(5):
msg = sub.receive(timeout=0.1)
@ -116,7 +188,7 @@ class TestStreamsSubscription:
break
received.append(msg)
assert received == [b"e1", b"e2"]
assert received == [b"after-subscribe-1", b"after-subscribe-2"]
def test_receive_timeout_returns_none(self, streams_channel: StreamsBroadcastChannel):
topic = streams_channel.topic("delta")
@ -132,8 +204,6 @@ class TestStreamsSubscription:
# Listener running; now close and ensure no crash
sub.close()
# After close, receive should raise SubscriptionClosedError
from libs.broadcast_channel.exc import SubscriptionClosedError
with pytest.raises(SubscriptionClosedError):
sub.receive()
@ -143,3 +213,141 @@ class TestStreamsSubscription:
topic.publish(b"payload")
# No expire recorded when retention is disabled
assert fake_redis._expire_calls.get("stream:zeta") is None
@pytest.mark.parametrize(
("case"),
build_listen_payload_cases(),
ids=lambda case: cast(ListenPayloadCase, case).name,
)
def test_listener_normalizes_supported_payloads_and_ignores_unsupported_shapes(self, case: ListenPayloadCase):
class OneShotRedis:
def __init__(self, fields: object) -> None:
self._fields = fields
self._calls = 0
def xread(self, streams: dict, block: int | None = None, count: int | None = None):
self._calls += 1
if self._calls == 1:
key = next(iter(streams))
return [(key, [("1-0", self._fields)])]
subscription._closed.set()
return []
subscription = _StreamsSubscription(OneShotRedis(case.fields), "stream:payload-shape")
subscription._listen()
received: list[bytes] = []
while not subscription._queue.empty():
item = subscription._queue.get_nowait()
if item is subscription._SENTINEL:
break
received.append(bytes(item))
assert received == case.expected_messages
assert subscription._last_id == "1-0"
def test_iterator_yields_messages_until_subscription_is_closed(self, streams_channel: StreamsBroadcastChannel):
topic = streams_channel.topic("iter")
subscription = topic.subscribe()
iterator = iter(subscription)
def publish_later() -> None:
time.sleep(0.05)
topic.publish(b"iter-message")
publisher = threading.Thread(target=publish_later, daemon=True)
publisher.start()
assert next(iterator) == b"iter-message"
subscription.close()
publisher.join(timeout=1)
with pytest.raises(StopIteration):
next(iterator)
def test_receive_with_none_timeout_blocks_until_message_arrives(self, streams_channel: StreamsBroadcastChannel):
topic = streams_channel.topic("blocking")
subscription = topic.subscribe()
def publish_later() -> None:
time.sleep(0.05)
topic.publish(b"blocking-message")
publisher = threading.Thread(target=publish_later, daemon=True)
publisher.start()
try:
assert subscription.receive(timeout=None) == b"blocking-message"
finally:
subscription.close()
publisher.join(timeout=1)
def test_receive_raises_when_queue_contains_close_sentinel(self):
subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:sentinel")
subscription._listener = threading.current_thread()
subscription._queue.put_nowait(subscription._SENTINEL)
with pytest.raises(SubscriptionClosedError):
subscription.receive(timeout=0.01)
def test_close_before_listener_starts_is_a_noop(self):
subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:not-started")
subscription.close()
assert subscription._listener is None
with pytest.raises(SubscriptionClosedError):
subscription.receive(timeout=0.01)
def test_start_if_needed_returns_immediately_for_closed_subscription(self):
subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:already-closed")
subscription._closed.set()
subscription._start_if_needed()
assert subscription._listener is None
def test_iterator_skips_none_results_and_keeps_polling(self):
subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:iterator-none")
items = iter([None, b"event"])
subscription._start_if_needed = lambda: None # type: ignore[method-assign]
def fake_receive(timeout: float | None = 0.1) -> bytes | None:
value = next(items)
if value is not None:
subscription._closed.set()
return value
subscription.receive = fake_receive # type: ignore[method-assign]
assert next(iter(subscription)) == b"event"
def test_close_logs_warning_when_listener_does_not_stop_in_time(
self,
caplog: pytest.LogCaptureFixture,
):
blocking_redis = BlockingRedis()
subscription = _StreamsSubscription(blocking_redis, "stream:slow-close")
subscription._start_if_needed()
listener = subscription._listener
assert listener is not None
original_join = listener.join
original_is_alive = listener.is_alive
def delayed_join(timeout: float | None = None) -> None:
original_join(0.01)
listener.join = delayed_join # type: ignore[method-assign]
listener.is_alive = lambda: True # type: ignore[method-assign]
try:
subscription.close()
assert "did not stop within timeout" in caplog.text
finally:
listener.join = original_join # type: ignore[method-assign]
listener.is_alive = original_is_alive # type: ignore[method-assign]
blocking_redis.release()
original_join(timeout=1)

View File

@ -15,6 +15,7 @@ from datetime import UTC, datetime
from unittest.mock import patch
from uuid import uuid4
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models.dataset import (
AppDatasetJoin,
ChildChunk,
@ -67,14 +68,14 @@ class TestDatasetModelValidation:
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
description="Test description",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model="text-embedding-ada-002",
embedding_model_provider="openai",
)
# Assert
assert dataset.description == "Test description"
assert dataset.indexing_technique == "high_quality"
assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert dataset.embedding_model == "text-embedding-ada-002"
assert dataset.embedding_model_provider == "openai"
@ -86,21 +87,21 @@ class TestDatasetModelValidation:
name="High Quality Dataset",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
dataset_economy = Dataset(
tenant_id=str(uuid4()),
name="Economy Dataset",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
indexing_technique="economy",
indexing_technique=IndexTechniqueType.ECONOMY,
)
# Assert
assert dataset_high_quality.indexing_technique == "high_quality"
assert dataset_economy.indexing_technique == "economy"
assert "high_quality" in Dataset.INDEXING_TECHNIQUE_LIST
assert "economy" in Dataset.INDEXING_TECHNIQUE_LIST
assert dataset_high_quality.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert dataset_economy.indexing_technique == IndexTechniqueType.ECONOMY
assert IndexTechniqueType.HIGH_QUALITY in Dataset.INDEXING_TECHNIQUE_LIST
assert IndexTechniqueType.ECONOMY in Dataset.INDEXING_TECHNIQUE_LIST
def test_dataset_provider_validation(self):
"""Test dataset provider values."""
@ -983,7 +984,7 @@ class TestModelIntegration:
name="Test Dataset",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
dataset.id = dataset_id
@ -1019,7 +1020,7 @@ class TestModelIntegration:
assert document.dataset_id == dataset_id
assert segment.dataset_id == dataset_id
assert segment.document_id == document_id
assert dataset.indexing_technique == "high_quality"
assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert document.word_count == 100
assert segment.status == SegmentStatus.COMPLETED

View File

@ -97,6 +97,7 @@ from unittest.mock import Mock, create_autospec, patch
import pytest
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models import Account, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
@ -149,7 +150,7 @@ class DatasetUpdateDeleteTestDataFactory:
name: str = "Test Dataset",
description: str = "Test description",
tenant_id: str = "tenant-123",
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider: str | None = "openai",
embedding_model: str | None = "text-embedding-ada-002",
collection_binding_id: str | None = "binding-123",
@ -237,7 +238,7 @@ class DatasetUpdateDeleteTestDataFactory:
@staticmethod
def create_knowledge_configuration_mock(
chunk_structure: str = "tree",
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider: str = "openai",
embedding_model: str = "text-embedding-ada-002",
keyword_number: int = 10,
@ -630,12 +631,12 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings:
dataset_id="dataset-123",
runtime_mode="rag_pipeline",
chunk_structure="tree",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock(
chunk_structure="list",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
)
@ -671,7 +672,7 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings:
# Assert
assert dataset.chunk_structure == "list"
assert dataset.indexing_technique == "high_quality"
assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert dataset.embedding_model == "text-embedding-ada-002"
assert dataset.embedding_model_provider == "openai"
assert dataset.collection_binding_id == "binding-123"
@ -698,12 +699,12 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings:
dataset_id="dataset-123",
runtime_mode="rag_pipeline",
chunk_structure="tree", # Existing structure
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock(
chunk_structure="list", # Different structure
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
mock_session.merge.return_value = dataset
@ -735,11 +736,11 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings:
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(
dataset_id="dataset-123",
runtime_mode="rag_pipeline",
indexing_technique="high_quality", # Current technique
indexing_technique=IndexTechniqueType.HIGH_QUALITY, # Current technique
)
knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock(
indexing_technique="economy", # Trying to change to economy
indexing_technique=IndexTechniqueType.ECONOMY, # Trying to change to economy
)
mock_session.merge.return_value = dataset

View File

@ -111,7 +111,7 @@ from unittest.mock import Mock, patch
import pytest
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.dataset import Dataset, DatasetProcessRule, Document
from services.dataset_service import DatasetService, DocumentService
@ -154,7 +154,7 @@ class DocumentValidationTestDataFactory:
dataset_id: str = "dataset-123",
tenant_id: str = "tenant-123",
doc_form: str | None = None,
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider: str = "openai",
embedding_model: str = "text-embedding-ada-002",
**kwargs,
@ -190,7 +190,7 @@ class DocumentValidationTestDataFactory:
data_source: DataSource | None = None,
process_rule: ProcessRule | None = None,
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
**kwargs,
) -> Mock:
"""
@ -448,7 +448,7 @@ class TestDatasetServiceCheckDatasetModelSetting:
"""
# Arrange
dataset = DocumentValidationTestDataFactory.create_dataset_mock(
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
)
@ -481,7 +481,7 @@ class TestDatasetServiceCheckDatasetModelSetting:
- No errors are raised
"""
# Arrange
dataset = DocumentValidationTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = DocumentValidationTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
# Act (should not raise)
DatasetService.check_dataset_model_setting(dataset)
@ -503,7 +503,7 @@ class TestDatasetServiceCheckDatasetModelSetting:
"""
# Arrange
dataset = DocumentValidationTestDataFactory.create_dataset_mock(
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="invalid-model",
)
@ -533,7 +533,7 @@ class TestDatasetServiceCheckDatasetModelSetting:
"""
# Arrange
dataset = DocumentValidationTestDataFactory.create_dataset_mock(
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
)

View File

@ -2,7 +2,7 @@ from unittest.mock import MagicMock, Mock, patch
import pytest
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models.account import Account
from models.dataset import ChildChunk, Dataset, Document, DocumentSegment
from models.enums import SegmentType
@ -111,7 +111,7 @@ class SegmentTestDataFactory:
def create_dataset_mock(
dataset_id: str = "dataset-123",
tenant_id: str = "tenant-123",
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
embedding_model: str = "text-embedding-ada-002",
embedding_model_provider: str = "openai",
**kwargs,
@ -163,7 +163,7 @@ class TestSegmentServiceCreateSegment:
"""Test successful creation of a segment."""
# Arrange
document = SegmentTestDataFactory.create_document_mock(word_count=100)
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
args = {"content": "New segment content", "keywords": ["test", "segment"]}
mock_query = MagicMock()
@ -212,7 +212,7 @@ class TestSegmentServiceCreateSegment:
"""Test creation of segment with QA model (requires answer)."""
# Arrange
document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100)
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
args = {"content": "What is AI?", "answer": "AI is Artificial Intelligence", "keywords": ["ai"]}
mock_query = MagicMock()
@ -247,7 +247,7 @@ class TestSegmentServiceCreateSegment:
"""Test creation of segment with high quality indexing technique."""
# Arrange
document = SegmentTestDataFactory.create_document_mock(word_count=100)
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="high_quality")
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
args = {"content": "New segment content", "keywords": ["test"]}
mock_query = MagicMock()
@ -289,7 +289,7 @@ class TestSegmentServiceCreateSegment:
"""Test segment creation when vector indexing fails."""
# Arrange
document = SegmentTestDataFactory.create_document_mock(word_count=100)
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
args = {"content": "New segment content", "keywords": ["test"]}
mock_query = MagicMock()
@ -342,7 +342,7 @@ class TestSegmentServiceUpdateSegment:
# Arrange
segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10)
document = SegmentTestDataFactory.create_document_mock(word_count=100)
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
args = SegmentUpdateArgs(content="Updated content", keywords=["updated"])
mock_db_session.query.return_value.where.return_value.first.return_value = segment
@ -431,7 +431,7 @@ class TestSegmentServiceUpdateSegment:
# Arrange
segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10)
document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100)
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
args = SegmentUpdateArgs(content="Updated question", answer="Updated answer", keywords=["qa"])
mock_db_session.query.return_value.where.return_value.first.return_value = segment

View File

@ -1,214 +0,0 @@
"""
Unit tests for services.advanced_prompt_template_service
"""
import copy
from core.prompt.prompt_templates.advanced_prompt_templates import (
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_CONTEXT,
CHAT_APP_CHAT_PROMPT_CONFIG,
CHAT_APP_COMPLETION_PROMPT_CONFIG,
COMPLETION_APP_CHAT_PROMPT_CONFIG,
COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
CONTEXT,
)
from models.model import AppMode
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
class TestAdvancedPromptTemplateService:
"""Test suite for AdvancedPromptTemplateService."""
def test_get_prompt_should_use_baichuan_prompt_when_model_name_contains_baichuan(self) -> None:
"""Test baichuan model names use baichuan context prompt."""
# Arrange
args = {
"app_mode": AppMode.CHAT,
"model_mode": "chat",
"model_name": "Baichuan2-13B",
"has_context": "true",
}
# Act
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert
assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT)
def test_get_prompt_should_use_common_prompt_when_model_name_not_baichuan(self) -> None:
"""Test non-baichuan model names use common prompt."""
# Arrange
args = {
"app_mode": AppMode.CHAT,
"model_mode": "completion",
"model_name": "gpt-4",
"has_context": "false",
}
original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert
assert result == original_config
assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG
def test_get_common_prompt_should_return_empty_dict_when_app_mode_invalid(self) -> None:
"""Test invalid app mode returns empty dict."""
# Arrange
app_mode = "invalid"
model_mode = "chat"
# Act
result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "true")
# Assert
assert result == {}
def test_get_common_prompt_should_prepend_context_for_completion_prompt(self) -> None:
"""Test context is prepended for completion prompt when has_context is true."""
# Arrange
original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true")
# Assert
assert result["completion_prompt_config"]["prompt"]["text"].startswith(CONTEXT)
assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG
def test_get_common_prompt_should_prepend_context_for_chat_prompt(self) -> None:
"""Test context is prepended for chat prompt when has_context is true."""
# Arrange
original_config = copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true")
# Assert
assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(CONTEXT)
assert original_config == COMPLETION_APP_CHAT_PROMPT_CONFIG
def test_get_common_prompt_should_return_chat_prompt_without_context_when_has_context_false(self) -> None:
"""Test chat prompt remains unchanged when has_context is false."""
# Arrange
original_config = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "false")
# Assert
assert result == original_config
assert original_config == CHAT_APP_CHAT_PROMPT_CONFIG
def test_get_common_prompt_should_return_completion_prompt_for_completion_app_mode(self) -> None:
"""Test completion app mode with completion model returns completion prompt."""
# Arrange
original_config = copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "false")
# Assert
assert result == original_config
assert original_config == COMPLETION_APP_COMPLETION_PROMPT_CONFIG
def test_get_common_prompt_should_return_empty_dict_when_model_mode_invalid(self) -> None:
"""Test invalid model mode returns empty dict."""
# Arrange
app_mode = AppMode.CHAT
model_mode = "invalid"
# Act
result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "false")
# Assert
assert result == {}
def test_get_completion_prompt_should_not_prepend_context_when_has_context_false(self) -> None:
"""Test helper keeps completion prompt unchanged when context is disabled."""
# Arrange
prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
original_text = prompt_template["completion_prompt_config"]["prompt"]["text"]
# Act
result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "false", CONTEXT)
# Assert
assert result["completion_prompt_config"]["prompt"]["text"] == original_text
def test_get_chat_prompt_should_not_prepend_context_when_has_context_false(self) -> None:
"""Test helper keeps chat prompt unchanged when context is disabled."""
# Arrange
prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"]
# Act
result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "false", CONTEXT)
# Assert
assert result["chat_prompt_config"]["prompt"][0]["text"] == original_text
def test_get_baichuan_prompt_should_return_chat_completion_config_when_chat_completion(self) -> None:
"""Test baichuan chat/completion returns the expected config."""
# Arrange
original_config = copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false")
# Assert
assert result == original_config
assert original_config == BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG
def test_get_baichuan_prompt_should_return_completion_chat_config_when_completion_chat(self) -> None:
"""Test baichuan completion/chat returns the expected config."""
# Arrange
original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "false")
# Assert
assert result == original_config
assert original_config == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG
def test_get_baichuan_prompt_should_return_completion_completion_config_when_enabled_context(self) -> None:
"""Test baichuan completion/completion prepends baichuan context when enabled."""
# Arrange
original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true")
# Assert
assert result["completion_prompt_config"]["prompt"]["text"].startswith(BAICHUAN_CONTEXT)
assert original_config == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG
def test_get_baichuan_prompt_should_return_chat_chat_config_when_enabled_context(self) -> None:
"""Test baichuan chat/chat prepends baichuan context when enabled."""
# Arrange
original_config = copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true")
# Assert
assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT)
assert original_config == BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG
def test_get_baichuan_prompt_should_return_empty_dict_when_invalid_inputs(self) -> None:
"""Test invalid baichuan mode combinations return empty dict."""
# Arrange
app_mode = "invalid"
model_mode = "invalid"
# Act
result = AdvancedPromptTemplateService.get_baichuan_prompt(app_mode, model_mode, "true")
# Assert
assert result == {}

View File

@ -4,7 +4,7 @@ from unittest.mock import Mock, create_autospec
import pytest
from redis.exceptions import LockNotOwnedError
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models.account import Account
from models.dataset import Dataset, Document
from services.dataset_service import DocumentService, SegmentService
@ -71,7 +71,7 @@ def test_save_document_with_dataset_id_ignores_lock_not_owned(
dataset.id = "ds-1"
dataset.tenant_id = fake_current_user.current_tenant_id
dataset.data_source_type = "upload_file"
dataset.indexing_technique = "high_quality" # so we skip re-initialization branch
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY # so we skip re-initialization branch
# Minimal knowledge_config stub that satisfies pre-lock code
info_list = types.SimpleNamespace(data_source_type="upload_file")
@ -80,7 +80,7 @@ def test_save_document_with_dataset_id_ignores_lock_not_owned(
doc_form=IndexStructureType.QA_INDEX,
original_document_id=None, # go into "new document" branch
data_source=data_source,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model=None,
embedding_model_provider=None,
retrieval_model=None,
@ -126,7 +126,7 @@ def test_add_segment_ignores_lock_not_owned(
dataset = create_autospec(Dataset, instance=True)
dataset.id = "ds-1"
dataset.tenant_id = fake_current_user.current_tenant_id
dataset.indexing_technique = "economy" # skip embedding/token calculation branch
dataset.indexing_technique = IndexTechniqueType.ECONOMY # skip embedding/token calculation branch
document = create_autospec(Document, instance=True)
document.id = "doc-1"
@ -169,7 +169,7 @@ def test_multi_create_segment_ignores_lock_not_owned(
dataset = create_autospec(Dataset, instance=True)
dataset.id = "ds-1"
dataset.tenant_id = fake_current_user.current_tenant_id
dataset.indexing_technique = "economy" # again, skip high_quality path
dataset.indexing_technique = IndexTechniqueType.ECONOMY # again, skip high_quality path
document = create_autospec(Document, instance=True)
document.id = "doc-1"

View File

@ -11,7 +11,7 @@ from unittest.mock import MagicMock
import pytest
import services.summary_index_service as summary_module
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models.enums import SegmentStatus, SummaryStatus
from services.summary_index_service import SummaryIndexService
@ -27,7 +27,7 @@ class _SessionContext:
return None
def _dataset(*, indexing_technique: str = "high_quality") -> MagicMock:
def _dataset(*, indexing_technique: str = IndexTechniqueType.HIGH_QUALITY) -> MagicMock:
dataset = MagicMock(name="dataset")
dataset.id = "dataset-1"
dataset.tenant_id = "tenant-1"
@ -169,7 +169,8 @@ def test_create_summary_record_creates_new(monkeypatch: pytest.MonkeyPatch) -> N
def test_vectorize_summary_skips_non_high_quality(monkeypatch: pytest.MonkeyPatch) -> None:
vector_cls = MagicMock()
monkeypatch.setattr(summary_module, "Vector", vector_cls)
SummaryIndexService.vectorize_summary(_summary_record(), _segment(), _dataset(indexing_technique="economy"))
dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY)
SummaryIndexService.vectorize_summary(_summary_record(), _segment(), dataset)
vector_cls.assert_not_called()
@ -621,7 +622,7 @@ def test_generate_and_vectorize_summary_creates_missing_record_and_logs_usage(mo
def test_generate_summaries_for_document_skip_conditions(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _dataset(indexing_technique="economy")
dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY)
document = MagicMock(spec=summary_module.DatasetDocument)
document.id = "doc-1"
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
@ -778,7 +779,7 @@ def test_disable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.Mo
def test_enable_summaries_for_segments_skips_non_high_quality() -> None:
SummaryIndexService.enable_summaries_for_segments(_dataset(indexing_technique="economy"))
SummaryIndexService.enable_summaries_for_segments(_dataset(indexing_technique=IndexTechniqueType.ECONOMY))
def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pytest.MonkeyPatch) -> None:
@ -932,9 +933,8 @@ def test_delete_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.Mon
def test_update_summary_for_segment_skip_conditions() -> None:
assert (
SummaryIndexService.update_summary_for_segment(_segment(), _dataset(indexing_technique="economy"), "x") is None
)
economy_dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY)
assert SummaryIndexService.update_summary_for_segment(_segment(), economy_dataset, "x") is None
seg = _segment(has_document=True)
seg.document.doc_form = IndexStructureType.QA_INDEX
assert SummaryIndexService.update_summary_for_segment(seg, _dataset(), "x") is None

View File

@ -9,7 +9,7 @@ from unittest.mock import MagicMock
import pytest
import services.vector_service as vector_service_module
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from services.vector_service import VectorService
@ -32,7 +32,7 @@ class _ParentDocStub:
def _make_dataset(
*,
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
tenant_id: str = "tenant-1",
dataset_id: str = "dataset-1",
@ -192,7 +192,7 @@ def test_create_segments_vector_parent_child_calls_generate_child_chunks_with_ex
dataset = _make_dataset(
doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX,
embedding_model_provider="openai",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
segment = _make_segment()
@ -241,7 +241,7 @@ def test_create_segments_vector_parent_child_uses_default_embedding_model_when_p
dataset = _make_dataset(
doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX,
embedding_model_provider=None,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
segment = _make_segment()
@ -329,7 +329,7 @@ def test_create_segments_vector_parent_child_missing_processing_rule_raises(monk
def test_create_segments_vector_parent_child_non_high_quality_raises(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(
doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX,
indexing_technique="economy",
indexing_technique=IndexTechniqueType.ECONOMY,
)
segment = _make_segment()
dataset_document = MagicMock()
@ -348,7 +348,7 @@ def test_create_segments_vector_parent_child_non_high_quality_raises(monkeypatch
def test_update_segment_vector_high_quality_uses_vector(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(indexing_technique="high_quality")
dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
segment = _make_segment()
vector_instance = MagicMock()
@ -364,7 +364,7 @@ def test_update_segment_vector_high_quality_uses_vector(monkeypatch: pytest.Monk
def test_update_segment_vector_economy_uses_keyword_with_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(indexing_technique="economy")
dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY)
segment = _make_segment()
keyword_instance = MagicMock()
@ -380,7 +380,7 @@ def test_update_segment_vector_economy_uses_keyword_with_keywords_list(monkeypat
def test_update_segment_vector_economy_uses_keyword_without_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(indexing_technique="economy")
dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY)
segment = _make_segment()
keyword_instance = MagicMock()
@ -473,7 +473,7 @@ def test_generate_child_chunks_commits_even_when_no_children(monkeypatch: pytest
def test_create_child_chunk_vector_high_quality_adds_texts(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(indexing_technique="high_quality")
dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
child_chunk = MagicMock()
child_chunk.content = "child"
child_chunk.index_node_id = "id"
@ -489,7 +489,7 @@ def test_create_child_chunk_vector_high_quality_adds_texts(monkeypatch: pytest.M
def test_create_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(indexing_technique="economy")
dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY)
vector_cls = MagicMock()
monkeypatch.setattr(vector_service_module, "Vector", vector_cls)
@ -505,7 +505,7 @@ def test_create_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch)
def test_update_child_chunk_vector_high_quality_updates_vector(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(indexing_technique="high_quality")
dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
new_chunk = MagicMock()
new_chunk.content = "n"
@ -536,7 +536,7 @@ def test_update_child_chunk_vector_high_quality_updates_vector(monkeypatch: pyte
def test_update_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(indexing_technique="economy")
dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY)
vector_cls = MagicMock()
monkeypatch.setattr(vector_service_module, "Vector", vector_cls)
VectorService.update_child_chunk_vector([], [], [], dataset)
@ -561,7 +561,7 @@ def test_delete_child_chunk_vector_deletes_by_id(monkeypatch: pytest.MonkeyPatch
def test_update_multimodel_vector_returns_when_not_high_quality(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(indexing_technique="economy", is_multimodal=True)
dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY, is_multimodal=True)
segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}])
vector_cls = MagicMock()
@ -575,7 +575,7 @@ def test_update_multimodel_vector_returns_when_not_high_quality(monkeypatch: pyt
def test_update_multimodel_vector_returns_when_no_actual_change(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True)
segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}, {"id": "b"}])
vector_cls = MagicMock()
@ -591,7 +591,7 @@ def test_update_multimodel_vector_returns_when_no_actual_change(monkeypatch: pyt
def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids(
monkeypatch: pytest.MonkeyPatch,
) -> None:
dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True)
segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}, {"id": "old-2"}])
vector_instance = MagicMock(name="vector_instance")
@ -612,7 +612,7 @@ def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids(
def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True)
segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}])
vector_instance = MagicMock()
@ -630,7 +630,7 @@ def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch
def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_upload_files(
monkeypatch: pytest.MonkeyPatch,
) -> None:
dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True)
segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}])
vector_instance = MagicMock()
@ -663,7 +663,7 @@ def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_up
def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops(
monkeypatch: pytest.MonkeyPatch,
) -> None:
dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=False)
dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=False)
segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}])
vector_instance = MagicMock()
@ -683,7 +683,7 @@ def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops
def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True)
segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}])
vector_instance = MagicMock()

View File

@ -1,379 +0,0 @@
from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from werkzeug.exceptions import NotFound, Unauthorized
from models import Account, AccountStatus
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
ACCOUNT_LOOKUP_PATH = "services.webapp_auth_service.AccountService.get_account_by_email_with_case_fallback"
TOKEN_GENERATE_PATH = "services.webapp_auth_service.TokenManager.generate_token"
TOKEN_GET_DATA_PATH = "services.webapp_auth_service.TokenManager.get_token_data"
def _account(**kwargs: Any) -> Account:
return cast(Account, SimpleNamespace(**kwargs))
@pytest.fixture
def mock_db(mocker: MockerFixture) -> MagicMock:
# Arrange
mocked_db = mocker.patch("services.webapp_auth_service.db")
mocked_db.session = MagicMock()
return mocked_db
def test_authenticate_should_raise_account_not_found_when_email_does_not_exist(mocker: MockerFixture) -> None:
# Arrange
mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None)
# Act + Assert
with pytest.raises(AccountNotFoundError):
WebAppAuthService.authenticate("user@example.com", "pwd")
def test_authenticate_should_raise_account_login_error_when_account_is_banned(mocker: MockerFixture) -> None:
# Arrange
account = SimpleNamespace(status=AccountStatus.BANNED, password="hash", password_salt="salt")
mocker.patch(
ACCOUNT_LOOKUP_PATH,
return_value=account,
)
# Act + Assert
with pytest.raises(AccountLoginError, match="Account is banned"):
WebAppAuthService.authenticate("user@example.com", "pwd")
@pytest.mark.parametrize("password_value", [None, "hash"])
def test_authenticate_should_raise_password_error_when_password_is_invalid(
password_value: str | None,
mocker: MockerFixture,
) -> None:
# Arrange
account = SimpleNamespace(status=AccountStatus.ACTIVE, password=password_value, password_salt="salt")
mocker.patch(
ACCOUNT_LOOKUP_PATH,
return_value=account,
)
mocker.patch("services.webapp_auth_service.compare_password", return_value=False)
# Act + Assert
with pytest.raises(AccountPasswordError, match="Invalid email or password"):
WebAppAuthService.authenticate("user@example.com", "pwd")
def test_authenticate_should_return_account_when_credentials_are_valid(mocker: MockerFixture) -> None:
# Arrange
account = SimpleNamespace(status=AccountStatus.ACTIVE, password="hash", password_salt="salt")
mocker.patch(
ACCOUNT_LOOKUP_PATH,
return_value=account,
)
mocker.patch("services.webapp_auth_service.compare_password", return_value=True)
# Act
result = WebAppAuthService.authenticate("user@example.com", "pwd")
# Assert
assert result is account
def test_login_should_return_token_from_internal_token_builder(mocker: MockerFixture) -> None:
# Arrange
account = _account(id="a1", email="u@example.com")
mock_get_token = mocker.patch.object(WebAppAuthService, "_get_account_jwt_token", return_value="jwt-token")
# Act
result = WebAppAuthService.login(account)
# Assert
assert result == "jwt-token"
mock_get_token.assert_called_once_with(account=account)
def test_get_user_through_email_should_return_none_when_account_not_found(mocker: MockerFixture) -> None:
# Arrange
mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None)
# Act
result = WebAppAuthService.get_user_through_email("missing@example.com")
# Assert
assert result is None
def test_get_user_through_email_should_raise_unauthorized_when_account_banned(mocker: MockerFixture) -> None:
# Arrange
account = SimpleNamespace(status=AccountStatus.BANNED)
mocker.patch(
ACCOUNT_LOOKUP_PATH,
return_value=account,
)
# Act + Assert
with pytest.raises(Unauthorized, match="Account is banned"):
WebAppAuthService.get_user_through_email("user@example.com")
def test_get_user_through_email_should_return_account_when_active(mocker: MockerFixture) -> None:
# Arrange
account = SimpleNamespace(status=AccountStatus.ACTIVE)
mocker.patch(
ACCOUNT_LOOKUP_PATH,
return_value=account,
)
# Act
result = WebAppAuthService.get_user_through_email("user@example.com")
# Assert
assert result is account
def test_send_email_code_login_email_should_raise_error_when_email_not_provided() -> None:
# Arrange
# Act + Assert
with pytest.raises(ValueError, match="Email must be provided"):
WebAppAuthService.send_email_code_login_email(account=None, email=None)
def test_send_email_code_login_email_should_generate_token_and_send_mail_for_account(
mocker: MockerFixture,
) -> None:
# Arrange
account = _account(email="user@example.com")
mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[1, 2, 3, 4, 5, 6])
mock_generate_token = mocker.patch(TOKEN_GENERATE_PATH, return_value="token-1")
mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay")
# Act
result = WebAppAuthService.send_email_code_login_email(account=account, language="en-US")
# Assert
assert result == "token-1"
mock_generate_token.assert_called_once()
assert mock_generate_token.call_args.kwargs["additional_data"] == {"code": "123456"}
mock_delay.assert_called_once_with(language="en-US", to="user@example.com", code="123456")
def test_send_email_code_login_email_should_send_mail_for_email_without_account(
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[0, 0, 0, 0, 0, 0])
mocker.patch(TOKEN_GENERATE_PATH, return_value="token-2")
mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay")
# Act
result = WebAppAuthService.send_email_code_login_email(account=None, email="alt@example.com", language="zh-Hans")
# Assert
assert result == "token-2"
mock_delay.assert_called_once_with(language="zh-Hans", to="alt@example.com", code="000000")
def test_get_email_code_login_data_should_delegate_to_token_manager(mocker: MockerFixture) -> None:
# Arrange
mock_get_data = mocker.patch(TOKEN_GET_DATA_PATH, return_value={"code": "123"})
# Act
result = WebAppAuthService.get_email_code_login_data("token-abc")
# Assert
assert result == {"code": "123"}
mock_get_data.assert_called_once_with("token-abc", "email_code_login")
def test_revoke_email_code_login_token_should_delegate_to_token_manager(mocker: MockerFixture) -> None:
# Arrange
mock_revoke = mocker.patch("services.webapp_auth_service.TokenManager.revoke_token")
# Act
WebAppAuthService.revoke_email_code_login_token("token-xyz")
# Assert
mock_revoke.assert_called_once_with("token-xyz", "email_code_login")
def test_create_end_user_should_raise_not_found_when_site_does_not_exist(mock_db: MagicMock) -> None:
# Arrange
mock_db.session.query.return_value.where.return_value.first.return_value = None
# Act + Assert
with pytest.raises(NotFound, match="Site not found"):
WebAppAuthService.create_end_user("app-code", "user@example.com")
def test_create_end_user_should_raise_not_found_when_app_does_not_exist(mock_db: MagicMock) -> None:
# Arrange
site = SimpleNamespace(app_id="app-1")
app_query = MagicMock()
app_query.where.return_value.first.return_value = None
mock_db.session.query.return_value.where.return_value.first.side_effect = [site, None]
# Act + Assert
with pytest.raises(NotFound, match="App not found"):
WebAppAuthService.create_end_user("app-code", "user@example.com")
def test_create_end_user_should_create_and_commit_end_user_when_data_is_valid(mock_db: MagicMock) -> None:
# Arrange
site = SimpleNamespace(app_id="app-1")
app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1")
mock_db.session.query.return_value.where.return_value.first.side_effect = [site, app_model]
# Act
result = WebAppAuthService.create_end_user("app-code", "user@example.com")
# Assert
assert result.tenant_id == "tenant-1"
assert result.app_id == "app-1"
assert result.session_id == "user@example.com"
mock_db.session.add.assert_called_once()
mock_db.session.commit.assert_called_once()
def test_get_account_jwt_token_should_build_payload_and_issue_token(mocker: MockerFixture) -> None:
# Arrange
account = _account(id="a1", email="user@example.com")
mocker.patch("services.webapp_auth_service.dify_config.ACCESS_TOKEN_EXPIRE_MINUTES", 60)
mock_issue = mocker.patch("services.webapp_auth_service.PassportService.issue", return_value="jwt-1")
# Act
token = WebAppAuthService._get_account_jwt_token(account)
# Assert
assert token == "jwt-1"
payload = mock_issue.call_args.args[0]
assert payload["user_id"] == "a1"
assert payload["session_id"] == "user@example.com"
assert payload["token_source"] == "webapp_login_token"
assert payload["auth_type"] == "internal"
assert payload["exp"] > int(datetime.now(UTC).timestamp())
@pytest.mark.parametrize(
("access_mode", "expected"),
[
("private", True),
("private_all", True),
("public", False),
],
)
def test_is_app_require_permission_check_should_use_access_mode_when_provided(
access_mode: str,
expected: bool,
) -> None:
# Arrange
# Act
result = WebAppAuthService.is_app_require_permission_check(access_mode=access_mode)
# Assert
assert result is expected
def test_is_app_require_permission_check_should_raise_when_no_identifier_provided() -> None:
# Arrange
# Act + Assert
with pytest.raises(ValueError, match="Either app_code or app_id must be provided"):
WebAppAuthService.is_app_require_permission_check()
def test_is_app_require_permission_check_should_raise_when_app_id_cannot_be_determined(mocker: MockerFixture) -> None:
# Arrange
mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value=None)
# Act + Assert
with pytest.raises(ValueError, match="App ID could not be determined"):
WebAppAuthService.is_app_require_permission_check(app_code="app-code")
def test_is_app_require_permission_check_should_return_true_when_enterprise_mode_requires_it(
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1")
mocker.patch(
"services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
return_value=SimpleNamespace(access_mode="private"),
)
# Act
result = WebAppAuthService.is_app_require_permission_check(app_code="app-code")
# Assert
assert result is True
def test_is_app_require_permission_check_should_return_false_when_enterprise_settings_do_not_require_it(
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch(
"services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
return_value=SimpleNamespace(access_mode="public"),
)
# Act
result = WebAppAuthService.is_app_require_permission_check(app_id="app-1")
# Assert
assert result is False
@pytest.mark.parametrize(
("access_mode", "expected"),
[
("public", WebAppAuthType.PUBLIC),
("private", WebAppAuthType.INTERNAL),
("private_all", WebAppAuthType.INTERNAL),
("sso_verified", WebAppAuthType.EXTERNAL),
],
)
def test_get_app_auth_type_should_map_access_modes_correctly(
access_mode: str,
expected: WebAppAuthType,
) -> None:
# Arrange
# Act
result = WebAppAuthService.get_app_auth_type(access_mode=access_mode)
# Assert
assert result == expected
def test_get_app_auth_type_should_resolve_from_app_code(mocker: MockerFixture) -> None:
# Arrange
mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1")
mocker.patch(
"services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
return_value=SimpleNamespace(access_mode="private_all"),
)
# Act
result = WebAppAuthService.get_app_auth_type(app_code="app-code")
# Assert
assert result == WebAppAuthType.INTERNAL
def test_get_app_auth_type_should_raise_when_no_input_provided() -> None:
# Arrange
# Act + Assert
with pytest.raises(ValueError, match="Either app_code or access_mode must be provided"):
WebAppAuthService.get_app_auth_type()
def test_get_app_auth_type_should_raise_when_cannot_determine_type_from_invalid_mode() -> None:
# Arrange
# Act + Assert
with pytest.raises(ValueError, match="Could not determine app authentication type"):
WebAppAuthService.get_app_auth_type(access_mode="unknown")

View File

@ -121,7 +121,7 @@ import pytest
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.models.document import Document
from models.dataset import ChildChunk, Dataset, DatasetDocument, DatasetProcessRule, DocumentSegment
from services.vector_service import VectorService
@ -153,7 +153,7 @@ class VectorServiceTestDataFactory:
dataset_id: str = "dataset-123",
tenant_id: str = "tenant-123",
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider: str = "openai",
embedding_model: str = "text-embedding-ada-002",
index_struct_dict: dict | None = None,
@ -494,7 +494,7 @@ class TestVectorService:
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(
doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_technique="high_quality"
doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_technique=IndexTechniqueType.HIGH_QUALITY
)
segment = VectorServiceTestDataFactory.create_document_segment_mock()
@ -535,7 +535,7 @@ class TestVectorService:
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(
doc_form="parent_child_model", indexing_technique="high_quality"
doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY
)
segment = VectorServiceTestDataFactory.create_document_segment_mock()
@ -568,7 +568,7 @@ class TestVectorService:
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(
doc_form="parent_child_model", indexing_technique="high_quality"
doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY
)
segment = VectorServiceTestDataFactory.create_document_segment_mock()
@ -591,7 +591,7 @@ class TestVectorService:
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(
doc_form="parent_child_model", indexing_technique="high_quality"
doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY
)
segment = VectorServiceTestDataFactory.create_document_segment_mock()
@ -616,7 +616,7 @@ class TestVectorService:
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(
doc_form="parent_child_model", indexing_technique="economy"
doc_form="parent_child_model", indexing_technique=IndexTechniqueType.ECONOMY
)
segment = VectorServiceTestDataFactory.create_document_segment_mock()
@ -669,7 +669,7 @@ class TestVectorService:
store when using high_quality indexing.
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality")
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
segment = VectorServiceTestDataFactory.create_document_segment_mock()
@ -695,7 +695,7 @@ class TestVectorService:
index when using economy indexing with keywords.
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
segment = VectorServiceTestDataFactory.create_document_segment_mock()
@ -731,7 +731,7 @@ class TestVectorService:
index when using economy indexing without keywords.
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
segment = VectorServiceTestDataFactory.create_document_segment_mock()
@ -895,7 +895,7 @@ class TestVectorService:
when using high_quality indexing.
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality")
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock()
@ -923,7 +923,7 @@ class TestVectorService:
using economy indexing.
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock()
@ -951,7 +951,7 @@ class TestVectorService:
when there are new chunks, updated chunks, and deleted chunks.
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality")
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock(chunk_id="new-chunk-1")
@ -993,7 +993,7 @@ class TestVectorService:
add_texts is called, not delete_by_ids.
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality")
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock()
@ -1019,7 +1019,7 @@ class TestVectorService:
delete_by_ids is called, not add_texts.
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality")
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
delete_chunk = VectorServiceTestDataFactory.create_child_chunk_mock()
@ -1045,7 +1045,7 @@ class TestVectorService:
using economy indexing.
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock()
@ -1075,7 +1075,7 @@ class TestVectorService:
when using high_quality indexing.
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality")
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock()
@ -1099,7 +1099,7 @@ class TestVectorService:
using economy indexing.
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock()

View File

@ -16,7 +16,7 @@ from unittest.mock import MagicMock, patch
import pytest
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models.enums import DataSourceType
from tasks.clean_dataset_task import clean_dataset_task
@ -184,7 +184,7 @@ class TestErrorHandling:
clean_dataset_task(
dataset_id=dataset_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
index_struct='{"type": "paragraph"}',
collection_binding_id=collection_binding_id,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
@ -229,7 +229,7 @@ class TestPipelineAndWorkflowDeletion:
clean_dataset_task(
dataset_id=dataset_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
index_struct='{"type": "paragraph"}',
collection_binding_id=collection_binding_id,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
@ -265,7 +265,7 @@ class TestPipelineAndWorkflowDeletion:
clean_dataset_task(
dataset_id=dataset_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
index_struct='{"type": "paragraph"}',
collection_binding_id=collection_binding_id,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
@ -321,7 +321,7 @@ class TestSegmentAttachmentCleanup:
clean_dataset_task(
dataset_id=dataset_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
index_struct='{"type": "paragraph"}',
collection_binding_id=collection_binding_id,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
@ -366,7 +366,7 @@ class TestSegmentAttachmentCleanup:
clean_dataset_task(
dataset_id=dataset_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
index_struct='{"type": "paragraph"}',
collection_binding_id=collection_binding_id,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
@ -408,7 +408,7 @@ class TestEdgeCases:
clean_dataset_task(
dataset_id=dataset_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
index_struct='{"type": "paragraph"}',
collection_binding_id=collection_binding_id,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
@ -445,7 +445,7 @@ class TestIndexProcessorParameters:
- Dataset object with correct attributes is passed
"""
# Arrange
indexing_technique = "high_quality"
indexing_technique = IndexTechniqueType.HIGH_QUALITY
index_struct = '{"type": "paragraph"}'
# Act

View File

@ -15,7 +15,7 @@ from unittest.mock import MagicMock, Mock, patch
import pytest
from core.indexing_runner import DocumentIsPausedError
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
@ -209,7 +209,7 @@ def mock_dataset(dataset_id, tenant_id):
dataset = Mock(spec=Dataset)
dataset.id = dataset_id
dataset.tenant_id = tenant_id
dataset.indexing_technique = "high_quality"
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
dataset.embedding_model_provider = "openai"
dataset.embedding_model = "text-embedding-ada-002"
return dataset