Merge remote-tracking branch 'origin/main' into feat/credit-pool

This commit is contained in:
Yansong Zhang
2026-01-08 10:09:09 +08:00
3289 changed files with 213946 additions and 152401 deletions

View File

@ -1294,6 +1294,42 @@ class TestBillingServiceSubscriptionOperations:
# Assert
assert result == {}
def test_get_plan_bulk_with_invalid_tenant_plan_skipped(self, mock_send_request):
    """Bulk plan retrieval drops a tenant whose plan payload is invalid.

    The mocked response holds two complete plans and one plan without
    ``expiration_date``. The test expects only the two complete plans in the
    result and exactly one ``logger.exception`` call mentioning the bad tenant.
    """
    # Arrange
    expected_plans = {
        "tenant-valid-1": {"plan": "sandbox", "expiration_date": 1735689600},
        "tenant-valid-2": {"plan": "team", "expiration_date": 1767225600},
    }
    tenant_ids = ["tenant-valid-1", "tenant-invalid", "tenant-valid-2"]
    # The middle entry deliberately lacks the expiration_date field.
    mock_send_request.return_value = {
        "data": {
            "tenant-valid-1": {"plan": "sandbox", "expiration_date": 1735689600},
            "tenant-invalid": {"plan": "professional"},
            "tenant-valid-2": {"plan": "team", "expiration_date": 1767225600},
        }
    }

    # Act
    with patch("services.billing_service.logger") as mock_logger:
        result = BillingService.get_plan_bulk(tenant_ids)

    # Assert - only the valid tenants survive, with their data intact
    assert len(result) == 2
    for valid_tenant in expected_plans:
        assert valid_tenant in result
    assert "tenant-invalid" not in result
    for valid_tenant, plan in expected_plans.items():
        assert result[valid_tenant]["plan"] == plan["plan"]
        assert result[valid_tenant]["expiration_date"] == plan["expiration_date"]

    # Assert - the invalid tenant was reported through logger.exception exactly once
    mock_logger.exception.assert_called_once()
    logged_args = mock_logger.exception.call_args[0]
    assert "get_plan_bulk: failed to validate subscription plan for tenant" in logged_args[0]
    assert "tenant-invalid" in logged_args[1]
def test_get_expired_subscription_cleanup_whitelist_success(self, mock_send_request):
"""Test successful retrieval of expired subscription cleanup whitelist."""
# Arrange

View File

@ -0,0 +1,472 @@
"""
Unit tests for SegmentService.get_segments method.
Tests the retrieval of document segments with pagination and filtering:
- Basic pagination (page, limit)
- Status filtering
- Keyword search
- Ordering by position and id (to avoid duplicate data)
"""
from unittest.mock import Mock, create_autospec, patch
import pytest
from models.dataset import DocumentSegment
class SegmentServiceTestDataFactory:
    """
    Builds the mock objects shared by the segment service tests.
    """

    @staticmethod
    def create_segment_mock(
        segment_id: str = "segment-123",
        document_id: str = "doc-123",
        tenant_id: str = "tenant-123",
        dataset_id: str = "dataset-123",
        position: int = 1,
        content: str = "Test content",
        status: str = "completed",
        **kwargs,
    ) -> Mock:
        """
        Build an autospec'd ``DocumentSegment`` instance with the given attributes.

        Args:
            segment_id: Value stored on the mock's ``id`` attribute
            document_id: Value stored on ``document_id``
            tenant_id: Value stored on ``tenant_id``
            dataset_id: Value stored on ``dataset_id``
            position: Value stored on ``position``
            content: Value stored on ``content``
            status: Value stored on ``status``
            **kwargs: Extra attribute name/value pairs set after the core fields

        Returns:
            Mock: The configured DocumentSegment mock
        """
        core_fields = {
            "id": segment_id,
            "document_id": document_id,
            "tenant_id": tenant_id,
            "dataset_id": dataset_id,
            "position": position,
            "content": content,
            "status": status,
        }
        mocked_segment = create_autospec(DocumentSegment, instance=True)
        # Core fields are applied first so a same-named extra in kwargs wins.
        for field_group in (core_fields, kwargs):
            for attr_name, attr_value in field_group.items():
                setattr(mocked_segment, attr_name, attr_value)
        return mocked_segment
class TestSegmentServiceGetSegments:
    """
    Unit tests for SegmentService.get_segments.

    ``db`` and ``select`` in ``services.dataset_service`` are replaced by mocks in
    every test, so no SQL is executed. The tests check how the query is assembled
    (number of ``where`` / ``order_by`` calls), which keyword arguments reach
    ``db.paginate``, and that the paginated items/total are returned unchanged.

    Tests cover:
    - Basic pagination functionality
    - Status list filtering
    - Keyword search filtering
    - Ordering (position + id for uniqueness)
    - Empty results
    - Combined filters
    """

    @pytest.fixture
    def mock_segment_service_dependencies(self):
        """
        Common mock setup for segment service dependencies.

        Patches:
        - db: Database operations and pagination
        - select: SQLAlchemy query builder

        Yields:
            dict with the two patched objects under the keys "db" and "select"
        """
        with (
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.select") as mock_select,
        ):
            yield {
                "db": mock_db,
                "select": mock_select,
            }

    @staticmethod
    def _stub_query_and_pagination(dependencies, segments):
        """
        Wire the patched ``db`` and ``select`` for a single get_segments call.

        Args:
            dependencies: The dict yielded by ``mock_segment_service_dependencies``
            segments: Items that ``db.paginate`` should report; ``total`` is their count

        Returns:
            Mock: The query mock returned by ``select``; its ``where`` and ``order_by``
            return the same mock so call counts can be inspected afterwards
        """
        paginated = Mock()
        paginated.items = segments
        paginated.total = len(segments)
        dependencies["db"].paginate.return_value = paginated

        query = Mock()
        dependencies["select"].return_value = query
        query.where.return_value = query
        query.order_by.return_value = query
        return query

    def test_get_segments_basic_pagination(self, mock_segment_service_dependencies):
        """
        Test basic pagination functionality.

        Verifies:
        - Segments and total count from the paginated result are returned
        - db.paginate is called once with the requested page and limit
        - max_per_page is 100 and error_out is False
        """
        # Arrange
        document_id = "doc-123"
        tenant_id = "tenant-123"
        page = 1
        limit = 20
        segments = [
            SegmentServiceTestDataFactory.create_segment_mock(
                segment_id="seg-1", position=1, content="First segment"
            ),
            SegmentServiceTestDataFactory.create_segment_mock(
                segment_id="seg-2", position=2, content="Second segment"
            ),
        ]
        self._stub_query_and_pagination(mock_segment_service_dependencies, segments)

        # Act
        from services.dataset_service import SegmentService

        items, total = SegmentService.get_segments(
            document_id=document_id, tenant_id=tenant_id, page=page, limit=limit
        )

        # Assert
        assert len(items) == 2
        assert total == 2
        assert items[0].id == "seg-1"
        assert items[1].id == "seg-2"
        mock_segment_service_dependencies["db"].paginate.assert_called_once()
        call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
        assert call_kwargs["page"] == page
        assert call_kwargs["per_page"] == limit
        assert call_kwargs["max_per_page"] == 100
        assert call_kwargs["error_out"] is False

    def test_get_segments_with_status_filter(self, mock_segment_service_dependencies):
        """
        Test filtering by status list.

        Verifies:
        - A non-empty status list adds one where() call on top of the base filters
        - The paginated items and total are returned
        """
        # Arrange
        document_id = "doc-123"
        tenant_id = "tenant-123"
        status_list = ["completed", "indexing"]
        segments = [
            SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1", status="completed"),
            SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-2", status="indexing"),
        ]
        mock_query = self._stub_query_and_pagination(mock_segment_service_dependencies, segments)

        # Act
        from services.dataset_service import SegmentService

        items, total = SegmentService.get_segments(
            document_id=document_id, tenant_id=tenant_id, status_list=status_list
        )

        # Assert
        assert len(items) == 2
        assert total == 2
        # Exactly base filters + status filter. The other tests in this class pin
        # 1 call with no optional filter and 3 with status + keyword, so a status
        # filter alone must contribute exactly one extra call.
        assert mock_query.where.call_count == 2

    def test_get_segments_with_empty_status_list(self, mock_segment_service_dependencies):
        """
        Test with empty status list.

        Verifies:
        - Empty status list is handled correctly
        - No status filter is applied to avoid WHERE false condition
        """
        # Arrange
        document_id = "doc-123"
        tenant_id = "tenant-123"
        status_list = []
        segments = [SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1")]
        mock_query = self._stub_query_and_pagination(mock_segment_service_dependencies, segments)

        # Act
        from services.dataset_service import SegmentService

        items, total = SegmentService.get_segments(
            document_id=document_id, tenant_id=tenant_id, status_list=status_list
        )

        # Assert
        assert len(items) == 1
        assert total == 1
        # Should only be called once (base filters, no status filter)
        assert mock_query.where.call_count == 1

    def test_get_segments_with_keyword_search(self, mock_segment_service_dependencies):
        """
        Test keyword search functionality.

        Verifies:
        - A keyword adds one where() call on top of the base filters
        - The paginated items and total are returned

        The filter expression itself is not inspected here because ``select`` is mocked.
        """
        # Arrange
        document_id = "doc-123"
        tenant_id = "tenant-123"
        keyword = "search term"
        segments = [
            SegmentServiceTestDataFactory.create_segment_mock(
                segment_id="seg-1", content="This contains search term"
            )
        ]
        mock_query = self._stub_query_and_pagination(mock_segment_service_dependencies, segments)

        # Act
        from services.dataset_service import SegmentService

        items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id, keyword=keyword)

        # Assert
        assert len(items) == 1
        assert total == 1
        # Verify where was called for base filters + keyword filter
        assert mock_query.where.call_count == 2

    def test_get_segments_ordering_by_position_and_id(self, mock_segment_service_dependencies):
        """
        Test that ordering is applied to the query.

        Verifies:
        - order_by is called exactly once on the query
        - Segments sharing a position are all returned

        The intent of the ordering (position, then id as a tie-breaker) is to prevent
        duplicate rows across pages when positions are not unique. The ordering
        clauses themselves are not inspected here because ``select`` is mocked.
        """
        # Arrange
        document_id = "doc-123"
        tenant_id = "tenant-123"
        # Two segments share position 1 but have different ids
        segments = [
            SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1", position=1, content="Content 1"),
            SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-2", position=1, content="Content 2"),
            SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-3", position=2, content="Content 3"),
        ]
        mock_query = self._stub_query_and_pagination(mock_segment_service_dependencies, segments)

        # Act
        from services.dataset_service import SegmentService

        items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id)

        # Assert
        assert len(items) == 3
        assert total == 3
        mock_query.order_by.assert_called_once()

    def test_get_segments_empty_results(self, mock_segment_service_dependencies):
        """
        Test when no segments match the criteria.

        Verifies:
        - Empty list is returned for items
        - Total count is 0
        """
        # Arrange
        document_id = "non-existent-doc"
        tenant_id = "tenant-123"
        self._stub_query_and_pagination(mock_segment_service_dependencies, [])

        # Act
        from services.dataset_service import SegmentService

        items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id)

        # Assert
        assert items == []
        assert total == 0

    def test_get_segments_combined_filters(self, mock_segment_service_dependencies):
        """
        Test with multiple filters combined.

        Verifies:
        - Status list and keyword each add a where() call (three in total)
        - The requested page and limit reach db.paginate
        """
        # Arrange
        document_id = "doc-123"
        tenant_id = "tenant-123"
        status_list = ["completed"]
        keyword = "important"
        page = 2
        limit = 10
        segments = [
            SegmentServiceTestDataFactory.create_segment_mock(
                segment_id="seg-1",
                status="completed",
                content="This is important information",
            )
        ]
        mock_query = self._stub_query_and_pagination(mock_segment_service_dependencies, segments)

        # Act
        from services.dataset_service import SegmentService

        items, total = SegmentService.get_segments(
            document_id=document_id,
            tenant_id=tenant_id,
            status_list=status_list,
            keyword=keyword,
            page=page,
            limit=limit,
        )

        # Assert
        assert len(items) == 1
        assert total == 1
        # Verify filters: base + status + keyword
        assert mock_query.where.call_count == 3
        # Verify pagination parameters
        call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
        assert call_kwargs["page"] == page
        assert call_kwargs["per_page"] == limit

    def test_get_segments_with_none_status_list(self, mock_segment_service_dependencies):
        """
        Test with None status list.

        Verifies:
        - None status list is handled correctly
        - No status filter is applied
        """
        # Arrange
        document_id = "doc-123"
        tenant_id = "tenant-123"
        segments = [SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1")]
        mock_query = self._stub_query_and_pagination(mock_segment_service_dependencies, segments)

        # Act
        from services.dataset_service import SegmentService

        items, total = SegmentService.get_segments(
            document_id=document_id,
            tenant_id=tenant_id,
            status_list=None,
        )

        # Assert
        assert len(items) == 1
        assert total == 1
        # Should only be called once (base filters only, no status filter)
        assert mock_query.where.call_count == 1

    def test_get_segments_pagination_max_per_page_limit(self, mock_segment_service_dependencies):
        """
        Test that max_per_page is correctly set to 100.

        Verifies:
        - max_per_page passed to db.paginate is 100 even when the requested
          limit (200) is larger
        """
        # Arrange
        document_id = "doc-123"
        tenant_id = "tenant-123"
        limit = 200  # Request more than max_per_page
        self._stub_query_and_pagination(mock_segment_service_dependencies, [])

        # Act
        from services.dataset_service import SegmentService

        SegmentService.get_segments(
            document_id=document_id,
            tenant_id=tenant_id,
            limit=limit,
        )

        # Assert
        call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
        assert call_kwargs["max_per_page"] == 100

View File

@ -27,7 +27,6 @@ def service_with_fake_configurations():
description=None,
icon_small=None,
icon_small_dark=None,
icon_large=None,
background=None,
help=None,
supported_model_types=[ModelType.LLM],