Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2025-09-13 01:27:37 +08:00
242 changed files with 10968 additions and 2216 deletions

View File

@ -203,6 +203,7 @@ ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
# Reset password token expiry minutes
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES=5
CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5
OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5

View File

@ -13,7 +13,6 @@ from services.account_service import AccountService, RegisterService, TenantServ
from services.errors.account import (
AccountAlreadyInTenantError,
AccountLoginError,
AccountNotFoundError,
AccountPasswordError,
AccountRegisterError,
CurrentPasswordIncorrectError,
@ -161,7 +160,7 @@ class TestAccountService:
fake = Faker()
email = fake.email()
password = fake.password(length=12)
with pytest.raises(AccountNotFoundError):
with pytest.raises(AccountPasswordError):
AccountService.authenticate(email, password)
def test_authenticate_banned_account(self, db_session_with_containers, mock_external_service_dependencies):

View File

@ -42,7 +42,7 @@ class TestAdvancedPromptTemplateService:
# Test data for Baichuan model
args = {
"app_mode": AppMode.CHAT.value,
"app_mode": AppMode.CHAT,
"model_mode": "completion",
"model_name": "baichuan-13b-chat",
"has_context": "true",
@ -77,7 +77,7 @@ class TestAdvancedPromptTemplateService:
# Test data for common model
args = {
"app_mode": AppMode.CHAT.value,
"app_mode": AppMode.CHAT,
"model_mode": "completion",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
@ -116,7 +116,7 @@ class TestAdvancedPromptTemplateService:
for model_name in test_cases:
args = {
"app_mode": AppMode.CHAT.value,
"app_mode": AppMode.CHAT,
"model_mode": "completion",
"model_name": model_name,
"has_context": "true",
@ -144,7 +144,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "true")
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true")
# Assert: Verify the expected outcomes
assert result is not None
@ -173,7 +173,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "chat", "true")
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "true")
# Assert: Verify the expected outcomes
assert result is not None
@ -202,7 +202,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "completion", "true")
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "true")
# Assert: Verify the expected outcomes
assert result is not None
@ -230,7 +230,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "chat", "true")
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true")
# Assert: Verify the expected outcomes
assert result is not None
@ -257,7 +257,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "false")
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "false")
# Assert: Verify the expected outcomes
assert result is not None
@ -303,7 +303,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "unsupported_mode", "true")
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "unsupported_mode", "true")
# Assert: Verify empty dict is returned
assert result == {}
@ -442,7 +442,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "true")
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "true")
# Assert: Verify the expected outcomes
assert result is not None
@ -473,7 +473,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "chat", "true")
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true")
# Assert: Verify the expected outcomes
assert result is not None
@ -502,7 +502,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "completion", "true")
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true")
# Assert: Verify the expected outcomes
assert result is not None
@ -530,7 +530,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "chat", "true")
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "true")
# Assert: Verify the expected outcomes
assert result is not None
@ -557,7 +557,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "false")
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false")
# Assert: Verify the expected outcomes
assert result is not None
@ -603,7 +603,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "unsupported_mode", "true")
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "unsupported_mode", "true")
# Assert: Verify empty dict is returned
assert result == {}
@ -621,7 +621,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()
# Test all app modes
app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value]
app_modes = [AppMode.CHAT, AppMode.COMPLETION]
model_modes = ["completion", "chat"]
for app_mode in app_modes:
@ -653,7 +653,7 @@ class TestAdvancedPromptTemplateService:
fake = Faker()
# Test all app modes
app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value]
app_modes = [AppMode.CHAT, AppMode.COMPLETION]
model_modes = ["completion", "chat"]
for app_mode in app_modes:
@ -686,10 +686,10 @@ class TestAdvancedPromptTemplateService:
# Test edge cases
edge_cases = [
{"app_mode": "", "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true"},
{"app_mode": AppMode.CHAT.value, "model_mode": "", "model_name": "gpt-3.5-turbo", "has_context": "true"},
{"app_mode": AppMode.CHAT.value, "model_mode": "completion", "model_name": "", "has_context": "true"},
{"app_mode": AppMode.CHAT, "model_mode": "", "model_name": "gpt-3.5-turbo", "has_context": "true"},
{"app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "", "has_context": "true"},
{
"app_mode": AppMode.CHAT.value,
"app_mode": AppMode.CHAT,
"model_mode": "completion",
"model_name": "gpt-3.5-turbo",
"has_context": "",
@ -723,7 +723,7 @@ class TestAdvancedPromptTemplateService:
# Test with context
args = {
"app_mode": AppMode.CHAT.value,
"app_mode": AppMode.CHAT,
"model_mode": "completion",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
@ -757,7 +757,7 @@ class TestAdvancedPromptTemplateService:
# Test with context
args = {
"app_mode": AppMode.CHAT.value,
"app_mode": AppMode.CHAT,
"model_mode": "completion",
"model_name": "baichuan-13b-chat",
"has_context": "true",
@ -786,25 +786,25 @@ class TestAdvancedPromptTemplateService:
# Test different scenarios
test_scenarios = [
{
"app_mode": AppMode.CHAT.value,
"app_mode": AppMode.CHAT,
"model_mode": "completion",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
},
{
"app_mode": AppMode.CHAT.value,
"app_mode": AppMode.CHAT,
"model_mode": "chat",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
},
{
"app_mode": AppMode.COMPLETION.value,
"app_mode": AppMode.COMPLETION,
"model_mode": "completion",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
},
{
"app_mode": AppMode.COMPLETION.value,
"app_mode": AppMode.COMPLETION,
"model_mode": "chat",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
@ -843,25 +843,25 @@ class TestAdvancedPromptTemplateService:
# Test different scenarios
test_scenarios = [
{
"app_mode": AppMode.CHAT.value,
"app_mode": AppMode.CHAT,
"model_mode": "completion",
"model_name": "baichuan-13b-chat",
"has_context": "true",
},
{
"app_mode": AppMode.CHAT.value,
"app_mode": AppMode.CHAT,
"model_mode": "chat",
"model_name": "baichuan-13b-chat",
"has_context": "true",
},
{
"app_mode": AppMode.COMPLETION.value,
"app_mode": AppMode.COMPLETION,
"model_mode": "completion",
"model_name": "baichuan-13b-chat",
"has_context": "true",
},
{
"app_mode": AppMode.COMPLETION.value,
"app_mode": AppMode.COMPLETION,
"model_mode": "chat",
"model_name": "baichuan-13b-chat",
"has_context": "true",

View File

@ -255,7 +255,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Try to create metadata with built-in field name
built_in_field_name = BuiltInField.document_name.value
built_in_field_name = BuiltInField.document_name
metadata_args = MetadataArgs(type="string", name=built_in_field_name)
# Act & Assert: Verify proper error handling
@ -375,7 +375,7 @@ class TestMetadataService:
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
# Try to update with built-in field name
built_in_field_name = BuiltInField.document_name.value
built_in_field_name = BuiltInField.document_name
with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."):
MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name)
@ -540,11 +540,11 @@ class TestMetadataService:
field_names = [field["name"] for field in result]
field_types = [field["type"] for field in result]
assert BuiltInField.document_name.value in field_names
assert BuiltInField.uploader.value in field_names
assert BuiltInField.upload_date.value in field_names
assert BuiltInField.last_update_date.value in field_names
assert BuiltInField.source.value in field_names
assert BuiltInField.document_name in field_names
assert BuiltInField.uploader in field_names
assert BuiltInField.upload_date in field_names
assert BuiltInField.last_update_date in field_names
assert BuiltInField.source in field_names
# Verify field types
assert "string" in field_types
@ -682,11 +682,11 @@ class TestMetadataService:
# Set document metadata with built-in fields
document.doc_metadata = {
BuiltInField.document_name.value: document.name,
BuiltInField.uploader.value: "test_uploader",
BuiltInField.upload_date.value: 1234567890.0,
BuiltInField.last_update_date.value: 1234567890.0,
BuiltInField.source.value: "test_source",
BuiltInField.document_name: document.name,
BuiltInField.uploader: "test_uploader",
BuiltInField.upload_date: 1234567890.0,
BuiltInField.last_update_date: 1234567890.0,
BuiltInField.source: "test_source",
}
db.session.add(document)
db.session.commit()

View File

@ -96,7 +96,7 @@ class TestWorkflowService:
app.tenant_id = fake.uuid4()
app.name = fake.company()
app.description = fake.text()
app.mode = AppMode.WORKFLOW.value
app.mode = AppMode.WORKFLOW
app.icon_type = "emoji"
app.icon = "🤖"
app.icon_background = "#FFEAD5"
@ -883,7 +883,7 @@ class TestWorkflowService:
# Create chat mode app
app = self._create_test_app(db_session_with_containers, fake)
app.mode = AppMode.CHAT.value
app.mode = AppMode.CHAT
# Create app model config (required for conversion)
from models.model import AppModelConfig
@ -926,7 +926,7 @@ class TestWorkflowService:
# Assert
assert result is not None
assert result.mode == AppMode.ADVANCED_CHAT.value # CHAT mode converts to ADVANCED_CHAT, not WORKFLOW
assert result.mode == AppMode.ADVANCED_CHAT # CHAT mode converts to ADVANCED_CHAT, not WORKFLOW
assert result.name == conversion_args["name"]
assert result.icon == conversion_args["icon"]
assert result.icon_type == conversion_args["icon_type"]
@ -945,7 +945,7 @@ class TestWorkflowService:
# Create completion mode app
app = self._create_test_app(db_session_with_containers, fake)
app.mode = AppMode.COMPLETION.value
app.mode = AppMode.COMPLETION
# Create app model config (required for conversion)
from models.model import AppModelConfig
@ -988,7 +988,7 @@ class TestWorkflowService:
# Assert
assert result is not None
assert result.mode == AppMode.WORKFLOW.value
assert result.mode == AppMode.WORKFLOW
assert result.name == conversion_args["name"]
assert result.icon == conversion_args["icon"]
assert result.icon_type == conversion_args["icon_type"]
@ -1007,7 +1007,7 @@ class TestWorkflowService:
# Create workflow mode app (already in workflow mode)
app = self._create_test_app(db_session_with_containers, fake)
app.mode = AppMode.WORKFLOW.value
app.mode = AppMode.WORKFLOW
from extensions.ext_database import db
@ -1030,7 +1030,7 @@ class TestWorkflowService:
# Arrange
fake = Faker()
app = self._create_test_app(db_session_with_containers, fake)
app.mode = AppMode.ADVANCED_CHAT.value
app.mode = AppMode.ADVANCED_CHAT
from extensions.ext_database import db
@ -1061,7 +1061,7 @@ class TestWorkflowService:
# Arrange
fake = Faker()
app = self._create_test_app(db_session_with_containers, fake)
app.mode = AppMode.WORKFLOW.value
app.mode = AppMode.WORKFLOW
from extensions.ext_database import db

View File

@ -0,0 +1,583 @@
"""
TestContainers-based integration tests for delete_segment_from_index_task.
This module provides comprehensive integration testing for the delete_segment_from_index_task
using TestContainers to ensure realistic database interactions and proper isolation.
The task is responsible for removing document segments from the vector index when segments
are deleted from the dataset.
"""
import logging
from unittest.mock import MagicMock, patch
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexType
from models import Account, Dataset, Document, DocumentSegment, Tenant
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
logger = logging.getLogger(__name__)
class TestDeleteSegmentFromIndexTask:
"""
Comprehensive integration tests for delete_segment_from_index_task using testcontainers.
This test class covers all major functionality of the delete_segment_from_index_task:
- Successful segment deletion from index
- Dataset not found scenarios
- Document not found scenarios
- Document status validation (disabled, archived, not completed)
- Index processor integration and cleanup
- Exception handling and error scenarios
- Performance and timing verification
All tests use the testcontainers infrastructure to ensure proper database isolation
and realistic testing environment with actual database interactions.
"""
def _create_test_tenant(self, db_session_with_containers, fake=None):
    """
    Build and persist a tenant fixture with realistic data.

    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
        fake: Optional Faker instance; a fresh one is created when omitted

    Returns:
        Tenant: The committed tenant record
    """
    faker = fake or Faker()
    created = faker.date_time_this_year()

    tenant = Tenant()
    tenant.id = faker.uuid4()
    tenant.name = f"Test Tenant {faker.company()}"
    tenant.plan = "basic"
    tenant.status = "active"
    # created_at and updated_at start out identical for a fresh row.
    tenant.created_at = created
    tenant.updated_at = created

    db_session_with_containers.add(tenant)
    db_session_with_containers.commit()
    return tenant
def _create_test_account(self, db_session_with_containers, tenant, fake=None):
    """
    Build and persist an account fixture owned by *tenant*.

    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
        tenant: Tenant the account belongs to
        fake: Optional Faker instance; a fresh one is created when omitted

    Returns:
        Account: The committed account record
    """
    faker = fake or Faker()
    account = Account()

    # Bulk-assign the scalar attributes; values mirror a typical active owner.
    attributes = {
        "id": faker.uuid4(),
        "email": faker.email(),
        "name": faker.name(),
        "avatar_url": faker.url(),
        "tenant_id": tenant.id,
        "status": "active",
        "type": "normal",
        "role": "owner",
        "interface_language": "en-US",
        "created_at": faker.date_time_this_year(),
    }
    for field, value in attributes.items():
        setattr(account, field, value)
    account.updated_at = account.created_at

    db_session_with_containers.add(account)
    db_session_with_containers.commit()
    return account
def _create_test_dataset(self, db_session_with_containers, tenant, account, fake=None):
    """
    Build and persist a dataset fixture owned by *tenant* and created by *account*.

    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
        tenant: Tenant the dataset belongs to
        account: Account recorded as creator and last updater
        fake: Optional Faker instance; a fresh one is created when omitted

    Returns:
        Dataset: The committed dataset record
    """
    faker = fake or Faker()
    dataset = Dataset()

    attributes = {
        "id": faker.uuid4(),
        "tenant_id": tenant.id,
        "name": f"Test Dataset {faker.word()}",
        "description": faker.text(max_nb_chars=200),
        "provider": "vendor",
        "permission": "only_me",
        "data_source_type": "upload_file",
        # High-quality indexing with a paragraph index structure so the
        # delete task exercises the vector-index code path.
        "indexing_technique": "high_quality",
        "index_struct": '{"type": "paragraph"}',
        "created_by": account.id,
        "created_at": faker.date_time_this_year(),
        "updated_by": account.id,
        "embedding_model": "text-embedding-ada-002",
        "embedding_model_provider": "openai",
        "built_in_field_enabled": False,
    }
    for field, value in attributes.items():
        setattr(dataset, field, value)
    dataset.updated_at = dataset.created_at

    db_session_with_containers.add(dataset)
    db_session_with_containers.commit()
    return dataset
def _create_test_document(self, db_session_with_containers, dataset, account, fake=None, **kwargs):
    """
    Build and persist a document fixture attached to *dataset*.

    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
        dataset: Dataset instance the document belongs to
        account: Account recorded as the document creator
        fake: Optional Faker instance; a fresh one is created when omitted
        **kwargs: Overrides for any of the defaulted document attributes
            (e.g. enabled, archived, indexing_status, doc_form)

    Returns:
        Document: The committed document record
    """
    faker = fake or Faker()

    # Default values for every overridable attribute; unknown kwargs are
    # ignored, matching the original kwargs.get(...) behavior.
    defaults = {
        "position": 1,
        "data_source_type": "upload_file",
        "data_source_info": "{}",
        "batch": faker.uuid4(),
        "name": f"Test Document {faker.word()}",
        "created_from": "api",
        "processing_started_at": faker.date_time_this_year(),
        "file_id": faker.uuid4(),
        "word_count": faker.random_int(min=100, max=1000),
        "parsing_completed_at": faker.date_time_this_year(),
        "cleaning_completed_at": faker.date_time_this_year(),
        "splitting_completed_at": faker.date_time_this_year(),
        "tokens": faker.random_int(min=50, max=500),
        "indexing_latency": faker.random_number(digits=3),
        "completed_at": faker.date_time_this_year(),
        "is_paused": False,
        "indexing_status": "completed",
        "enabled": True,
        "archived": False,
        "doc_type": "text",
        "doc_metadata": {},
        "doc_form": IndexType.PARAGRAPH_INDEX,
        "doc_language": "en",
    }
    defaults.update({key: value for key, value in kwargs.items() if key in defaults})

    document = Document()
    document.id = faker.uuid4()
    document.tenant_id = dataset.tenant_id
    document.dataset_id = dataset.id
    document.created_by = account.id
    document.created_at = faker.date_time_this_year()
    document.updated_at = faker.date_time_this_year()
    for key, value in defaults.items():
        setattr(document, key, value)

    db_session_with_containers.add(document)
    db_session_with_containers.commit()
    return document
def _create_test_document_segments(self, db_session_with_containers, document, account, count=3, fake=None):
    """
    Build and persist *count* document segments for *document*.

    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
        document: Document the segments belong to
        account: Account recorded as creator and last updater
        count: Number of segments to create (default 3)
        fake: Optional Faker instance; a fresh one is created when omitted

    Returns:
        list[DocumentSegment]: The committed segment records, in position order
    """
    faker = fake or Faker()

    def build_segment(position):
        # One enabled, completed segment with a unique index node ID so the
        # delete task has real node identifiers to clean from the index.
        segment = DocumentSegment()
        segment.id = faker.uuid4()
        segment.tenant_id = document.tenant_id
        segment.dataset_id = document.dataset_id
        segment.document_id = document.id
        segment.position = position
        segment.content = f"Test segment content {position}: {faker.text(max_nb_chars=200)}"
        segment.answer = f"Test segment answer {position}: {faker.text(max_nb_chars=100)}"
        segment.word_count = faker.random_int(min=10, max=100)
        segment.tokens = faker.random_int(min=5, max=50)
        segment.keywords = [faker.word() for _ in range(3)]
        segment.index_node_id = f"node_{faker.uuid4()}"
        segment.index_node_hash = faker.sha256()
        segment.hit_count = 0
        segment.enabled = True
        segment.status = "completed"
        segment.created_by = account.id
        segment.created_at = faker.date_time_this_year()
        segment.updated_by = account.id
        segment.updated_at = segment.created_at
        db_session_with_containers.add(segment)
        return segment

    segments = [build_segment(index + 1) for index in range(count)]
    db_session_with_containers.commit()
    return segments
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
def test_delete_segment_from_index_task_success(self, mock_index_processor_factory, db_session_with_containers):
    """
    Verify the happy path: given a valid dataset and a completed, enabled
    document, the task builds an index processor for the document's form and
    issues exactly one clean() call with the segment node IDs, keyword
    cleanup, and child-chunk deletion enabled, then returns None.
    """
    faker = Faker()

    # Arrange: full fixture chain tenant -> account -> dataset -> document -> segments.
    tenant = self._create_test_tenant(db_session_with_containers, faker)
    account = self._create_test_account(db_session_with_containers, tenant, faker)
    dataset = self._create_test_dataset(db_session_with_containers, tenant, account, faker)
    document = self._create_test_document(db_session_with_containers, dataset, account, faker)
    segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, faker)
    node_ids = [segment.index_node_id for segment in segments]

    stub_processor = MagicMock()
    mock_index_processor_factory.return_value.init_index_processor.return_value = stub_processor

    # Act
    outcome = delete_segment_from_index_task(node_ids, dataset.id, document.id)

    # Assert: the task returns None on success and delegates cleanup once.
    assert outcome is None
    mock_index_processor_factory.assert_called_once_with(document.doc_form)
    assert stub_processor.clean.call_count == 1

    positional, keyword = stub_processor.clean.call_args
    # The task re-queries the dataset, so compare by ID rather than identity.
    assert positional[0].id == dataset.id
    assert positional[1] == node_ids
    assert keyword["with_keywords"] is True
    assert keyword["delete_child_chunks"] is True
def test_delete_segment_from_index_task_dataset_not_found(self, db_session_with_containers):
    """
    Verify the task exits early and returns None when the dataset ID does not
    match any row, raising no exception and attempting no index cleanup.
    """
    faker = Faker()

    # Arrange: IDs that are guaranteed not to exist in the database.
    missing_dataset_id = faker.uuid4()
    missing_document_id = faker.uuid4()
    node_ids = [f"node_{faker.uuid4()}" for _ in range(3)]

    # Act
    outcome = delete_segment_from_index_task(node_ids, missing_dataset_id, missing_document_id)

    # Assert: graceful early return, no exception propagates.
    assert outcome is None
def test_delete_segment_from_index_task_document_not_found(self, db_session_with_containers):
    """
    Verify the task exits early and returns None when the dataset exists but
    the document ID does not, raising no exception and attempting no cleanup.
    """
    faker = Faker()

    # Arrange: a real dataset paired with a document ID that does not exist.
    tenant = self._create_test_tenant(db_session_with_containers, faker)
    account = self._create_test_account(db_session_with_containers, tenant, faker)
    dataset = self._create_test_dataset(db_session_with_containers, tenant, account, faker)
    missing_document_id = faker.uuid4()
    node_ids = [f"node_{faker.uuid4()}" for _ in range(3)]

    # Act
    outcome = delete_segment_from_index_task(node_ids, dataset.id, missing_document_id)

    # Assert: graceful early return, no exception propagates.
    assert outcome is None
def test_delete_segment_from_index_task_document_disabled(self, db_session_with_containers):
    """
    Verify the task exits early and returns None when the target document is
    disabled (enabled=False), raising no exception and attempting no cleanup.
    """
    faker = Faker()

    # Arrange: fixtures with the document explicitly disabled.
    tenant = self._create_test_tenant(db_session_with_containers, faker)
    account = self._create_test_account(db_session_with_containers, tenant, faker)
    dataset = self._create_test_dataset(db_session_with_containers, tenant, account, faker)
    document = self._create_test_document(db_session_with_containers, dataset, account, faker, enabled=False)
    segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, faker)
    node_ids = [segment.index_node_id for segment in segments]

    # Act
    outcome = delete_segment_from_index_task(node_ids, dataset.id, document.id)

    # Assert: graceful early return, no exception propagates.
    assert outcome is None
def test_delete_segment_from_index_task_document_archived(self, db_session_with_containers):
    """
    Verify the task exits early and returns None when the target document is
    archived, raising no exception and attempting no cleanup.
    """
    faker = Faker()

    # Arrange: fixtures with the document flagged as archived.
    tenant = self._create_test_tenant(db_session_with_containers, faker)
    account = self._create_test_account(db_session_with_containers, tenant, faker)
    dataset = self._create_test_dataset(db_session_with_containers, tenant, account, faker)
    document = self._create_test_document(db_session_with_containers, dataset, account, faker, archived=True)
    segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, faker)
    node_ids = [segment.index_node_id for segment in segments]

    # Act
    outcome = delete_segment_from_index_task(node_ids, dataset.id, document.id)

    # Assert: graceful early return, no exception propagates.
    assert outcome is None
def test_delete_segment_from_index_task_document_not_completed(self, db_session_with_containers):
    """
    Verify the task exits early and returns None when the document's indexing
    status is not "completed", raising no exception and attempting no cleanup.
    """
    faker = Faker()

    # Arrange: fixtures with the document still mid-indexing.
    tenant = self._create_test_tenant(db_session_with_containers, faker)
    account = self._create_test_account(db_session_with_containers, tenant, faker)
    dataset = self._create_test_dataset(db_session_with_containers, tenant, account, faker)
    document = self._create_test_document(
        db_session_with_containers, dataset, account, faker, indexing_status="indexing"
    )
    segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, faker)
    node_ids = [segment.index_node_id for segment in segments]

    # Act
    outcome = delete_segment_from_index_task(node_ids, dataset.id, document.id)

    # Assert: graceful early return, no exception propagates.
    assert outcome is None
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
def test_delete_segment_from_index_task_index_processor_clean(
    self, mock_index_processor_factory, db_session_with_containers
):
    """
    Verify index-processor integration for every supported document form.

    For each form (paragraph, QA, parent-child) this test checks:
    - The factory is constructed exactly once with that document form
    - clean() receives the matching dataset, node IDs, and flags
    - The task completes and returns None
    """
    fake = Faker()
    # Every index type the processor factory is expected to handle.
    document_forms = [IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX, IndexType.PARENT_CHILD_INDEX]
    for doc_form in document_forms:
        # Fresh fixture chain per form so iterations stay independent.
        tenant = self._create_test_tenant(db_session_with_containers, fake)
        account = self._create_test_account(db_session_with_containers, tenant, fake)
        dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
        document = self._create_test_document(db_session_with_containers, dataset, account, fake, doc_form=doc_form)
        segments = self._create_test_document_segments(db_session_with_containers, document, account, 2, fake)
        index_node_ids = [segment.index_node_id for segment in segments]

        # Mock the index processor
        mock_processor = MagicMock()
        mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor

        # Execute the task
        result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)

        # Verify the task completed successfully
        assert result is None
        # Fix: assert_called_with only inspects the most recent call, so it
        # would not catch duplicate factory invocations. After the per-iteration
        # reset_mock, assert_called_once_with is strictly stronger and matches
        # the assertion style used in the success test.
        mock_index_processor_factory.assert_called_once_with(doc_form)

        # Verify index processor clean method was called with correct parameters
        assert mock_processor.clean.call_count == 1
        call_args = mock_processor.clean.call_args
        assert call_args[0][0].id == dataset.id  # dataset comes from a fresh query; compare by ID
        assert call_args[0][1] == index_node_ids  # Verify index node IDs match
        assert call_args[1]["with_keywords"] is True
        assert call_args[1]["delete_child_chunks"] is True

        # Reset mocks so the next document form starts from a clean slate.
        mock_index_processor_factory.reset_mock()
        mock_processor.reset_mock()
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
def test_delete_segment_from_index_task_exception_handling(
    self, mock_index_processor_factory, db_session_with_containers
):
    """
    Verify the task swallows index-processor failures: when clean() raises,
    no exception propagates to the caller, the task still returns None, and
    the clean() call was attempted exactly once with the expected arguments.
    """
    faker = Faker()

    # Arrange: valid fixtures so execution reaches the index processor.
    tenant = self._create_test_tenant(db_session_with_containers, faker)
    account = self._create_test_account(db_session_with_containers, tenant, faker)
    dataset = self._create_test_dataset(db_session_with_containers, tenant, account, faker)
    document = self._create_test_document(db_session_with_containers, dataset, account, faker)
    segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, faker)
    node_ids = [segment.index_node_id for segment in segments]

    # Arrange: a processor whose clean() always blows up.
    failing_processor = MagicMock()
    failing_processor.clean.side_effect = Exception("Index processor error")
    mock_index_processor_factory.return_value.init_index_processor.return_value = failing_processor

    # Act: must not raise despite the processor failure.
    outcome = delete_segment_from_index_task(node_ids, dataset.id, document.id)

    # Assert: failure is contained and the clean attempt was made once.
    assert outcome is None
    assert failing_processor.clean.call_count == 1

    positional, keyword = failing_processor.clean.call_args
    # The task re-queries the dataset, so compare by ID rather than identity.
    assert positional[0].id == dataset.id
    assert positional[1] == node_ids
    assert keyword["with_keywords"] is True
    assert keyword["delete_child_chunks"] is True
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
def test_delete_segment_from_index_task_empty_index_node_ids(
    self, mock_index_processor_factory, db_session_with_containers
):
    """
    Verify the task tolerates an empty list of index node IDs.

    Checks that:
    - the task completes and returns None
    - the processor's clean() is still invoked once and receives the empty
      list verbatim
    - the keyword flags are forwarded unchanged
    """
    fake = Faker()

    # Fixtures: tenant -> account -> dataset -> document (no segments needed).
    tenant = self._create_test_tenant(db_session_with_containers, fake)
    account = self._create_test_account(db_session_with_containers, tenant, fake)
    dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
    document = self._create_test_document(db_session_with_containers, dataset, account, fake)

    # Deliberately empty input.
    empty_ids = []

    processor = MagicMock()
    mock_index_processor_factory.return_value.init_index_processor.return_value = processor

    # Run the task; it should succeed and return None.
    assert delete_segment_from_index_task(empty_ids, dataset.id, document.id) is None

    # clean() still runs once and is handed the empty list.
    assert processor.clean.call_count == 1
    positional, keywords = processor.clean.call_args
    assert positional[0].id == dataset.id
    assert positional[1] == empty_ids
    assert keywords["with_keywords"] is True
    assert keywords["delete_child_chunks"] is True
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
def test_delete_segment_from_index_task_large_index_node_ids(
    self, mock_index_processor_factory, db_session_with_containers
):
    """
    Verify the task handles a large batch of index node IDs.

    Checks that:
    - the task completes and returns None with 50 segments
    - the processor's clean() receives every node ID in a single call
    - the keyword flags are forwarded unchanged
    """
    fake = Faker()

    # Fixtures, including a comparatively large number of segments.
    tenant = self._create_test_tenant(db_session_with_containers, fake)
    account = self._create_test_account(db_session_with_containers, tenant, fake)
    dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
    document = self._create_test_document(db_session_with_containers, dataset, account, fake)
    segments = self._create_test_document_segments(db_session_with_containers, document, account, 50, fake)
    node_ids = [seg.index_node_id for seg in segments]

    processor = MagicMock()
    mock_index_processor_factory.return_value.init_index_processor.return_value = processor

    # Run the task; it should succeed and return None.
    assert delete_segment_from_index_task(node_ids, dataset.id, document.id) is None

    # A single clean() call carries the whole batch.
    assert processor.clean.call_count == 1
    positional, keywords = processor.clean.call_args
    assert positional[0].id == dataset.id
    assert positional[1] == node_ids
    assert keywords["with_keywords"] is True
    assert keywords["delete_child_chunks"] is True

    # Every one of the 50 node IDs was passed through.
    assert len(positional[1]) == 50

View File

@ -9,7 +9,6 @@ from flask_restx import Api
import services.errors.account
from controllers.console.auth.error import AuthenticationFailedError
from controllers.console.auth.login import LoginApi
from controllers.console.error import AccountNotFound
class TestAuthenticationSecurity:
@ -27,31 +26,33 @@ class TestAuthenticationSecurity:
@patch("controllers.console.auth.login.FeatureService.get_system_features")
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.send_reset_password_email")
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
def test_login_invalid_email_with_registration_allowed(
self, mock_get_invitation, mock_send_email, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
):
"""Test that invalid email sends reset password email when registration is allowed."""
"""Test that invalid email raises AuthenticationFailedError when account not found."""
# Arrange
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountNotFoundError("Account not found")
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_features.return_value.is_allow_register = True
mock_send_email.return_value = "token123"
# Act
with self.app.test_request_context(
"/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"}
):
login_api = LoginApi()
result = login_api.post()
# Assert
assert result == {"result": "fail", "data": "token123", "code": "account_not_found"}
mock_send_email.assert_called_once_with(email="nonexistent@example.com", language="en-US")
# Assert
with pytest.raises(AuthenticationFailedError) as exc_info:
login_api.post()
assert exc_info.value.error_code == "authentication_failed"
assert exc_info.value.description == "Invalid email or password."
mock_add_rate_limit.assert_called_once_with("nonexistent@example.com")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@ -87,16 +88,17 @@ class TestAuthenticationSecurity:
@patch("controllers.console.auth.login.FeatureService.get_system_features")
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
def test_login_invalid_email_with_registration_disabled(
self, mock_get_invitation, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
):
"""Test that invalid email raises AccountNotFound when registration is disabled."""
"""Test that invalid email raises AuthenticationFailedError when account not found."""
# Arrange
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountNotFoundError("Account not found")
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_features.return_value.is_allow_register = False
@ -107,10 +109,12 @@ class TestAuthenticationSecurity:
login_api = LoginApi()
# Assert
with pytest.raises(AccountNotFound) as exc_info:
with pytest.raises(AuthenticationFailedError) as exc_info:
login_api.post()
assert exc_info.value.error_code == "account_not_found"
assert exc_info.value.error_code == "authentication_failed"
assert exc_info.value.description == "Invalid email or password."
mock_add_rate_limit.assert_called_once_with("nonexistent@example.com")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.FeatureService.get_system_features")

View File

@ -12,7 +12,7 @@ from controllers.console.auth.oauth import (
)
from libs.oauth import OAuthUserInfo
from models.account import AccountStatus
from services.errors.account import AccountNotFoundError
from services.errors.account import AccountRegisterError
class TestGetOAuthProviders:
@ -451,7 +451,7 @@ class TestAccountGeneration:
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
if not allow_register and not existing_account:
with pytest.raises(AccountNotFoundError):
with pytest.raises(AccountRegisterError):
_generate_account("github", user_info)
else:
result = _generate_account("github", user_info)

View File

@ -29,7 +29,7 @@ class TestHandleMCPRequest:
"""Setup test fixtures"""
self.app = Mock(spec=App)
self.app.name = "test_app"
self.app.mode = AppMode.CHAT.value
self.app.mode = AppMode.CHAT
self.mcp_server = Mock(spec=AppMCPServer)
self.mcp_server.description = "Test server"
@ -196,7 +196,7 @@ class TestIndividualHandlers:
def test_handle_list_tools(self):
"""Test list tools handler"""
app_name = "test_app"
app_mode = AppMode.CHAT.value
app_mode = AppMode.CHAT
description = "Test server"
parameters_dict: dict[str, str] = {}
user_input_form: list[VariableEntity] = []
@ -212,7 +212,7 @@ class TestIndividualHandlers:
def test_handle_call_tool(self, mock_app_generate):
"""Test call tool handler"""
app = Mock(spec=App)
app.mode = AppMode.CHAT.value
app.mode = AppMode.CHAT
# Create mock request
mock_request = Mock()
@ -252,7 +252,7 @@ class TestUtilityFunctions:
def test_build_parameter_schema_chat_mode(self):
"""Test building parameter schema for chat mode"""
app_mode = AppMode.CHAT.value
app_mode = AppMode.CHAT
parameters_dict: dict[str, str] = {"name": "Enter your name"}
user_input_form = [
@ -275,7 +275,7 @@ class TestUtilityFunctions:
def test_build_parameter_schema_workflow_mode(self):
"""Test building parameter schema for workflow mode"""
app_mode = AppMode.WORKFLOW.value
app_mode = AppMode.WORKFLOW
parameters_dict: dict[str, str] = {"input_text": "Enter text"}
user_input_form = [
@ -298,7 +298,7 @@ class TestUtilityFunctions:
def test_prepare_tool_arguments_chat_mode(self):
"""Test preparing tool arguments for chat mode"""
app = Mock(spec=App)
app.mode = AppMode.CHAT.value
app.mode = AppMode.CHAT
arguments = {"query": "test question", "name": "John"}
@ -312,7 +312,7 @@ class TestUtilityFunctions:
def test_prepare_tool_arguments_workflow_mode(self):
"""Test preparing tool arguments for workflow mode"""
app = Mock(spec=App)
app.mode = AppMode.WORKFLOW.value
app.mode = AppMode.WORKFLOW
arguments = {"input_text": "test input"}
@ -324,7 +324,7 @@ class TestUtilityFunctions:
def test_prepare_tool_arguments_completion_mode(self):
"""Test preparing tool arguments for completion mode"""
app = Mock(spec=App)
app.mode = AppMode.COMPLETION.value
app.mode = AppMode.COMPLETION
arguments = {"name": "John"}
@ -336,7 +336,7 @@ class TestUtilityFunctions:
def test_extract_answer_from_mapping_response_chat(self):
"""Test extracting answer from mapping response for chat mode"""
app = Mock(spec=App)
app.mode = AppMode.CHAT.value
app.mode = AppMode.CHAT
response = {"answer": "test answer", "other": "data"}
@ -347,7 +347,7 @@ class TestUtilityFunctions:
def test_extract_answer_from_mapping_response_workflow(self):
"""Test extracting answer from mapping response for workflow mode"""
app = Mock(spec=App)
app.mode = AppMode.WORKFLOW.value
app.mode = AppMode.WORKFLOW
response = {"data": {"outputs": {"result": "test result"}}}

View File

@ -20,7 +20,6 @@ def test_firecrawl_web_extractor_crawl_mode(mocker):
}
mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl))
job_id = firecrawl_app.crawl_url(url, params)
print(f"job_id: {job_id}")
assert job_id is not None
assert isinstance(job_id, str)

View File

@ -129,7 +129,6 @@ class TestSegmentDumpAndLoad:
"""Test basic segment serialization compatibility"""
model = _Segments(segments=[IntegerSegment(value=1), StringSegment(value="a")])
json = model.model_dump_json()
print("Json: ", json)
loaded = _Segments.model_validate_json(json)
assert loaded == model
@ -137,7 +136,6 @@ class TestSegmentDumpAndLoad:
"""Test number segment serialization compatibility"""
model = _Segments(segments=[IntegerSegment(value=1), FloatSegment(value=1.0)])
json = model.model_dump_json()
print("Json: ", json)
loaded = _Segments.model_validate_json(json)
assert loaded == model
@ -145,7 +143,6 @@ class TestSegmentDumpAndLoad:
"""Test variable serialization compatibility"""
model = _Variables(variables=[IntegerVariable(value=1, name="int"), StringVariable(value="a", name="str")])
json = model.model_dump_json()
print("Json: ", json)
restored = _Variables.model_validate_json(json)
assert restored == model

View File

@ -20,7 +20,7 @@ from pathlib import Path
from typing import Any, Optional
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.utils.yaml_utils import load_yaml_file
from core.tools.utils.yaml_utils import _load_yaml_file
from core.variables import (
ArrayNumberVariable,
ArrayObjectVariable,
@ -713,4 +713,4 @@ def _load_fixture(fixture_path: Path, fixture_name: str) -> dict[str, Any]:
if not fixture_path.exists():
raise FileNotFoundError(f"Fixture file not found: {fixture_path}")
return load_yaml_file(str(fixture_path), ignore_error=False)
return _load_yaml_file(file_path=str(fixture_path))

View File

@ -142,15 +142,11 @@ def test_remove_first_from_array():
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment
# Print the variable before running
print(f"Before: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}")
# Run the node
result = list(node.run())
# Print the variable after running and the result
print(f"After: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}")
print(f"Result: {result}")
# Completed run
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None

View File

@ -0,0 +1,456 @@
import pytest
from core.file.enums import FileType
from core.file.models import File, FileTransferMethod
from core.variables.variables import StringVariable
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
class TestWorkflowEntry:
    """Test WorkflowEntry.mapping_user_inputs_to_variable_pool for all variable kinds."""

    def test_mapping_user_inputs_to_variable_pool_with_system_variables(self):
        """Test mapping system variables from user inputs to variable pool."""
        # Initialize variable pool with system variables
        variable_pool = VariablePool(
            system_variables=SystemVariable(
                user_id="test_user_id",
                app_id="test_app_id",
                workflow_id="test_workflow_id",
            ),
            user_inputs={},
        )
        # Define variable mapping - sys variables mapped to other nodes
        variable_mapping = {
            "node1.input1": ["node1", "input1"],  # Regular mapping
            "node2.query": ["node2", "query"],  # Regular mapping
            "sys.user_id": ["output_node", "user"],  # System variable mapping
        }
        # User inputs including sys variables
        user_inputs = {
            "node1.input1": "new_user_id",
            "node2.query": "test query",
            "sys.user_id": "system_user",
        }
        # Execute mapping
        WorkflowEntry.mapping_user_inputs_to_variable_pool(
            variable_mapping=variable_mapping,
            user_inputs=user_inputs,
            variable_pool=variable_pool,
            tenant_id="test_tenant",
        )
        # Verify variables were added to pool
        # Note: variable_pool.get returns Variable objects, not raw values
        node1_var = variable_pool.get(["node1", "input1"])
        assert node1_var is not None
        assert node1_var.value == "new_user_id"
        node2_var = variable_pool.get(["node2", "query"])
        assert node2_var is not None
        assert node2_var.value == "test query"
        # System variable gets mapped to output node
        output_var = variable_pool.get(["output_node", "user"])
        assert output_var is not None
        assert output_var.value == "system_user"

    def test_mapping_user_inputs_to_variable_pool_with_env_variables(self):
        """Test mapping environment variables from user inputs to variable pool."""
        # Initialize variable pool with environment variables
        env_var = StringVariable(name="API_KEY", value="existing_key")
        variable_pool = VariablePool(
            system_variables=SystemVariable.empty(),
            environment_variables=[env_var],
            user_inputs={},
        )
        # Add env variable to pool (simulating initialization)
        variable_pool.add([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"], env_var)
        # Define variable mapping - env variables should not be overridden
        variable_mapping = {
            "node1.api_key": [ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"],
            "node2.new_env": [ENVIRONMENT_VARIABLE_NODE_ID, "NEW_ENV"],
        }
        # User inputs
        user_inputs = {
            "node1.api_key": "user_provided_key",  # This should not override existing env var
            "node2.new_env": "new_env_value",
        }
        # Execute mapping
        WorkflowEntry.mapping_user_inputs_to_variable_pool(
            variable_mapping=variable_mapping,
            user_inputs=user_inputs,
            variable_pool=variable_pool,
            tenant_id="test_tenant",
        )
        # Verify env variable was not overridden
        env_value = variable_pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"])
        assert env_value is not None
        assert env_value.value == "existing_key"  # Should remain unchanged
        # New env variables from user input should not be added
        assert variable_pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "NEW_ENV"]) is None

    def test_mapping_user_inputs_to_variable_pool_with_conversation_variables(self):
        """Test mapping conversation variables from user inputs to variable pool."""
        # Initialize variable pool with conversation variables
        conv_var = StringVariable(name="last_message", value="Hello")
        variable_pool = VariablePool(
            system_variables=SystemVariable.empty(),
            conversation_variables=[conv_var],
            user_inputs={},
        )
        # Add conversation variable to pool
        variable_pool.add([CONVERSATION_VARIABLE_NODE_ID, "last_message"], conv_var)
        # Define variable mapping
        variable_mapping = {
            "node1.message": ["node1", "message"],  # Map to regular node
            "conversation.context": ["chat_node", "context"],  # Conversation var to regular node
        }
        # User inputs
        user_inputs = {
            "node1.message": "Updated message",
            "conversation.context": "New context",
        }
        # Execute mapping
        WorkflowEntry.mapping_user_inputs_to_variable_pool(
            variable_mapping=variable_mapping,
            user_inputs=user_inputs,
            variable_pool=variable_pool,
            tenant_id="test_tenant",
        )
        # Verify variables were added to their target nodes
        node1_var = variable_pool.get(["node1", "message"])
        assert node1_var is not None
        assert node1_var.value == "Updated message"
        chat_var = variable_pool.get(["chat_node", "context"])
        assert chat_var is not None
        assert chat_var.value == "New context"

    def test_mapping_user_inputs_to_variable_pool_with_regular_variables(self):
        """Test mapping regular node variables from user inputs to variable pool."""
        # Initialize empty variable pool
        variable_pool = VariablePool(
            system_variables=SystemVariable.empty(),
            user_inputs={},
        )
        # Define variable mapping for regular nodes
        variable_mapping = {
            "input_node.text": ["input_node", "text"],
            "llm_node.prompt": ["llm_node", "prompt"],
            "code_node.input": ["code_node", "input"],
        }
        # User inputs
        user_inputs = {
            "input_node.text": "User input text",
            "llm_node.prompt": "Generate a summary",
            "code_node.input": {"key": "value"},
        }
        # Execute mapping
        WorkflowEntry.mapping_user_inputs_to_variable_pool(
            variable_mapping=variable_mapping,
            user_inputs=user_inputs,
            variable_pool=variable_pool,
            tenant_id="test_tenant",
        )
        # Verify regular variables were added
        text_var = variable_pool.get(["input_node", "text"])
        assert text_var is not None
        assert text_var.value == "User input text"
        prompt_var = variable_pool.get(["llm_node", "prompt"])
        assert prompt_var is not None
        assert prompt_var.value == "Generate a summary"
        input_var = variable_pool.get(["code_node", "input"])
        assert input_var is not None
        assert input_var.value == {"key": "value"}

    def test_mapping_user_inputs_with_file_handling(self):
        """Test mapping file inputs from user inputs to variable pool."""
        variable_pool = VariablePool(
            system_variables=SystemVariable.empty(),
            user_inputs={},
        )
        # Define variable mapping
        variable_mapping = {
            "file_node.file": ["file_node", "file"],
            "file_node.files": ["file_node", "files"],
        }
        # User inputs with file data - using remote_url which doesn't require upload_file_id
        user_inputs = {
            "file_node.file": {
                "type": "document",
                "transfer_method": "remote_url",
                "url": "http://example.com/test.pdf",
            },
            "file_node.files": [
                {
                    "type": "image",
                    "transfer_method": "remote_url",
                    "url": "http://example.com/image1.jpg",
                },
                {
                    "type": "image",
                    "transfer_method": "remote_url",
                    "url": "http://example.com/image2.jpg",
                },
            ],
        }
        # Execute mapping
        WorkflowEntry.mapping_user_inputs_to_variable_pool(
            variable_mapping=variable_mapping,
            user_inputs=user_inputs,
            variable_pool=variable_pool,
            tenant_id="test_tenant",
        )
        # Verify file was converted and added
        file_var = variable_pool.get(["file_node", "file"])
        assert file_var is not None
        assert file_var.value.type == FileType.DOCUMENT
        assert file_var.value.transfer_method == FileTransferMethod.REMOTE_URL
        # Verify file list was converted and added
        files_var = variable_pool.get(["file_node", "files"])
        assert files_var is not None
        assert isinstance(files_var.value, list)
        assert len(files_var.value) == 2
        assert all(isinstance(f, File) for f in files_var.value)
        # NOTE(review): the original had these two assertions duplicated verbatim;
        # the redundant copies were removed.
        assert files_var.value[0].type == FileType.IMAGE
        assert files_var.value[1].type == FileType.IMAGE

    def test_mapping_user_inputs_missing_variable_error(self):
        """Test that mapping raises error when required variable is missing."""
        variable_pool = VariablePool(
            system_variables=SystemVariable.empty(),
            user_inputs={},
        )
        # Define variable mapping
        variable_mapping = {
            "node1.required_input": ["node1", "required_input"],
        }
        # User inputs without required variable
        user_inputs = {
            "node1.other_input": "some value",
        }
        # Should raise ValueError for missing variable
        with pytest.raises(ValueError, match="Variable key node1.required_input not found in user inputs"):
            WorkflowEntry.mapping_user_inputs_to_variable_pool(
                variable_mapping=variable_mapping,
                user_inputs=user_inputs,
                variable_pool=variable_pool,
                tenant_id="test_tenant",
            )

    def test_mapping_user_inputs_with_alternative_key_format(self):
        """Test mapping with alternative key format (without node prefix)."""
        variable_pool = VariablePool(
            system_variables=SystemVariable.empty(),
            user_inputs={},
        )
        # Define variable mapping
        variable_mapping = {
            "node1.input": ["node1", "input"],
        }
        # User inputs with alternative key format
        user_inputs = {
            "input": "value without node prefix",  # Alternative format without node prefix
        }
        # Execute mapping
        WorkflowEntry.mapping_user_inputs_to_variable_pool(
            variable_mapping=variable_mapping,
            user_inputs=user_inputs,
            variable_pool=variable_pool,
            tenant_id="test_tenant",
        )
        # Verify variable was added using alternative key
        input_var = variable_pool.get(["node1", "input"])
        assert input_var is not None
        assert input_var.value == "value without node prefix"

    def test_mapping_user_inputs_with_complex_selectors(self):
        """Test mapping with complex node variable keys."""
        variable_pool = VariablePool(
            system_variables=SystemVariable.empty(),
            user_inputs={},
        )
        # Define variable mapping - selectors can only have 2 elements
        variable_mapping = {
            "node1.data.field1": ["node1", "data_field1"],  # Complex key mapped to simple selector
            "node2.config.settings.timeout": ["node2", "timeout"],  # Complex key mapped to simple selector
        }
        # User inputs
        user_inputs = {
            "node1.data.field1": "nested value",
            "node2.config.settings.timeout": 30,
        }
        # Execute mapping
        WorkflowEntry.mapping_user_inputs_to_variable_pool(
            variable_mapping=variable_mapping,
            user_inputs=user_inputs,
            variable_pool=variable_pool,
            tenant_id="test_tenant",
        )
        # Verify variables were added with simple selectors
        data_var = variable_pool.get(["node1", "data_field1"])
        assert data_var is not None
        assert data_var.value == "nested value"
        timeout_var = variable_pool.get(["node2", "timeout"])
        assert timeout_var is not None
        assert timeout_var.value == 30

    def test_mapping_user_inputs_invalid_node_variable(self):
        """Test that mapping handles invalid node variable format."""
        variable_pool = VariablePool(
            system_variables=SystemVariable.empty(),
            user_inputs={},
        )
        # Define variable mapping with single element node variable (at least one dot is required)
        variable_mapping = {
            "singleelement": ["node1", "input"],  # No dot separator
        }
        user_inputs = {"singleelement": "some value"}  # Must use exact key
        # Should NOT raise error - function accepts it and uses direct key
        WorkflowEntry.mapping_user_inputs_to_variable_pool(
            variable_mapping=variable_mapping,
            user_inputs=user_inputs,
            variable_pool=variable_pool,
            tenant_id="test_tenant",
        )
        # Verify it was added
        var = variable_pool.get(["node1", "input"])
        assert var is not None
        assert var.value == "some value"

    def test_mapping_all_variable_types_together(self):
        """Test mapping all four types of variables in one operation."""
        # Initialize variable pool with some existing variables
        env_var = StringVariable(name="API_KEY", value="existing_key")
        conv_var = StringVariable(name="session_id", value="session123")
        variable_pool = VariablePool(
            system_variables=SystemVariable(
                user_id="test_user",
                app_id="test_app",
                query="initial query",
            ),
            environment_variables=[env_var],
            conversation_variables=[conv_var],
            user_inputs={},
        )
        # Add existing variables to pool
        variable_pool.add([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"], env_var)
        variable_pool.add([CONVERSATION_VARIABLE_NODE_ID, "session_id"], conv_var)
        # Define comprehensive variable mapping
        variable_mapping = {
            # System variables mapped to regular nodes
            "sys.user_id": ["start", "user"],
            "sys.app_id": ["start", "app"],
            # Environment variables (won't be overridden)
            "env.API_KEY": ["config", "api_key"],
            # Conversation variables mapped to regular nodes
            "conversation.session_id": ["chat", "session"],
            # Regular variables
            "input.text": ["input", "text"],
            "process.data": ["process", "data"],
        }
        # User inputs
        user_inputs = {
            "sys.user_id": "new_user",
            "sys.app_id": "new_app",
            "env.API_KEY": "attempted_override",  # Should not override env var
            "conversation.session_id": "new_session",
            "input.text": "user input text",
            "process.data": {"value": 123, "status": "active"},
        }
        # Execute mapping
        WorkflowEntry.mapping_user_inputs_to_variable_pool(
            variable_mapping=variable_mapping,
            user_inputs=user_inputs,
            variable_pool=variable_pool,
            tenant_id="test_tenant",
        )
        # Verify system variables were added to their target nodes
        start_user = variable_pool.get(["start", "user"])
        assert start_user is not None
        assert start_user.value == "new_user"
        start_app = variable_pool.get(["start", "app"])
        assert start_app is not None
        assert start_app.value == "new_app"
        # Verify env variable was not overridden (still has original value)
        env_value = variable_pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"])
        assert env_value is not None
        assert env_value.value == "existing_key"
        # Environment variables get mapped to other nodes even when they exist in env pool
        # But the original env value remains unchanged
        config_api_key = variable_pool.get(["config", "api_key"])
        assert config_api_key is not None
        assert config_api_key.value == "attempted_override"
        # Verify conversation variable was mapped to target node
        chat_session = variable_pool.get(["chat", "session"])
        assert chat_session is not None
        assert chat_session.value == "new_session"
        # Verify regular variables were added
        input_text = variable_pool.get(["input", "text"])
        assert input_text is not None
        assert input_text.value == "user input text"
        process_data = variable_pool.get(["process", "data"])
        assert process_data is not None
        assert process_data.value == {"value": 123, "status": "active"}

View File

@ -11,12 +11,12 @@ class TestSupabaseStorage:
def test_init_success_with_all_config(self):
"""Test successful initialization when all required config is provided."""
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
mock_config.SUPABASE_URL = "https://test.supabase.co"
mock_config.SUPABASE_API_KEY = "test-api-key"
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
with patch("extensions.storage.supabase_storage.Client") as mock_client_class:
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
mock_client = Mock()
mock_client_class.return_value = mock_client
@ -31,7 +31,7 @@ class TestSupabaseStorage:
def test_init_raises_error_when_url_missing(self):
"""Test initialization raises ValueError when SUPABASE_URL is None."""
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
mock_config.SUPABASE_URL = None
mock_config.SUPABASE_API_KEY = "test-api-key"
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
@ -41,7 +41,7 @@ class TestSupabaseStorage:
def test_init_raises_error_when_api_key_missing(self):
"""Test initialization raises ValueError when SUPABASE_API_KEY is None."""
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
mock_config.SUPABASE_URL = "https://test.supabase.co"
mock_config.SUPABASE_API_KEY = None
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
@ -51,7 +51,7 @@ class TestSupabaseStorage:
def test_init_raises_error_when_bucket_name_missing(self):
"""Test initialization raises ValueError when SUPABASE_BUCKET_NAME is None."""
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
mock_config.SUPABASE_URL = "https://test.supabase.co"
mock_config.SUPABASE_API_KEY = "test-api-key"
mock_config.SUPABASE_BUCKET_NAME = None
@ -61,12 +61,12 @@ class TestSupabaseStorage:
def test_create_bucket_when_not_exists(self):
"""Test create_bucket creates bucket when it doesn't exist."""
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
mock_config.SUPABASE_URL = "https://test.supabase.co"
mock_config.SUPABASE_API_KEY = "test-api-key"
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
with patch("extensions.storage.supabase_storage.Client") as mock_client_class:
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
mock_client = Mock()
mock_client_class.return_value = mock_client
@ -77,12 +77,12 @@ class TestSupabaseStorage:
def test_create_bucket_when_exists(self):
"""Test create_bucket does not create bucket when it already exists."""
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
mock_config.SUPABASE_URL = "https://test.supabase.co"
mock_config.SUPABASE_API_KEY = "test-api-key"
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
with patch("extensions.storage.supabase_storage.Client") as mock_client_class:
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
mock_client = Mock()
mock_client_class.return_value = mock_client
@ -94,12 +94,12 @@ class TestSupabaseStorage:
@pytest.fixture
def storage_with_mock_client(self):
"""Fixture providing SupabaseStorage with mocked client."""
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
mock_config.SUPABASE_URL = "https://test.supabase.co"
mock_config.SUPABASE_API_KEY = "test-api-key"
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
with patch("extensions.storage.supabase_storage.Client") as mock_client_class:
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
mock_client = Mock()
mock_client_class.return_value = mock_client
@ -251,12 +251,12 @@ class TestSupabaseStorage:
def test_bucket_exists_returns_true_when_bucket_found(self):
"""Test bucket_exists returns True when bucket is found in list."""
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
mock_config.SUPABASE_URL = "https://test.supabase.co"
mock_config.SUPABASE_API_KEY = "test-api-key"
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
with patch("extensions.storage.supabase_storage.Client") as mock_client_class:
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
mock_client = Mock()
mock_client_class.return_value = mock_client
@ -271,12 +271,12 @@ class TestSupabaseStorage:
def test_bucket_exists_returns_false_when_bucket_not_found(self):
"""Test bucket_exists returns False when bucket is not found in list."""
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
mock_config.SUPABASE_URL = "https://test.supabase.co"
mock_config.SUPABASE_API_KEY = "test-api-key"
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
with patch("extensions.storage.supabase_storage.Client") as mock_client_class:
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
mock_client = Mock()
mock_client_class.return_value = mock_client
@ -294,12 +294,12 @@ class TestSupabaseStorage:
def test_bucket_exists_returns_false_when_no_buckets(self):
"""Test bucket_exists returns False when no buckets exist."""
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
mock_config.SUPABASE_URL = "https://test.supabase.co"
mock_config.SUPABASE_API_KEY = "test-api-key"
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
with patch("extensions.storage.supabase_storage.Client") as mock_client_class:
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
mock_client = Mock()
mock_client_class.return_value = mock_client

View File

@ -10,7 +10,6 @@ from services.account_service import AccountService, RegisterService, TenantServ
from services.errors.account import (
AccountAlreadyInTenantError,
AccountLoginError,
AccountNotFoundError,
AccountPasswordError,
AccountRegisterError,
CurrentPasswordIncorrectError,
@ -195,7 +194,7 @@ class TestAccountService:
# Execute test and verify exception
self._assert_exception_raised(
AccountNotFoundError, AccountService.authenticate, "notfound@example.com", "password"
AccountPasswordError, AccountService.authenticate, "notfound@example.com", "password"
)
def test_authenticate_account_banned(self, mock_db_dependencies):

View File

@ -66,7 +66,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables):
app_model = MagicMock()
app_model.id = "app_id"
app_model.tenant_id = "tenant_id"
app_model.mode = AppMode.CHAT.value
app_model.mode = AppMode.CHAT
api_based_extension_id = "api_based_extension_id"
mock_api_based_extension = APIBasedExtension(
@ -127,7 +127,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables):
app_model = MagicMock()
app_model.id = "app_id"
app_model.tenant_id = "tenant_id"
app_model.mode = AppMode.WORKFLOW.value
app_model.mode = AppMode.WORKFLOW
api_based_extension_id = "api_based_extension_id"
mock_api_based_extension = APIBasedExtension(

View File

@ -279,8 +279,6 @@ def test_structured_output_parser():
]
for case in testcases:
print(f"Running test case: {case['name']}")
# Setup model entity
model_schema = get_model_entity(case["provider"], case["model_name"], case["support_structure_output"])

View File

@ -3,7 +3,7 @@ from textwrap import dedent
import pytest
from yaml import YAMLError
from core.tools.utils.yaml_utils import load_yaml_file
from core.tools.utils.yaml_utils import _load_yaml_file
EXAMPLE_YAML_FILE = "example_yaml.yaml"
INVALID_YAML_FILE = "invalid_yaml.yaml"
@ -56,15 +56,15 @@ def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str:
def test_load_yaml_non_existing_file():
assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {}
assert load_yaml_file(file_path="") == {}
with pytest.raises(FileNotFoundError):
_load_yaml_file(file_path=NON_EXISTING_YAML_FILE)
with pytest.raises(FileNotFoundError):
load_yaml_file(file_path=NON_EXISTING_YAML_FILE, ignore_error=False)
_load_yaml_file(file_path="")
def test_load_valid_yaml_file(prepare_example_yaml_file):
yaml_data = load_yaml_file(file_path=prepare_example_yaml_file)
yaml_data = _load_yaml_file(file_path=prepare_example_yaml_file)
assert len(yaml_data) > 0
assert yaml_data["age"] == 30
assert yaml_data["gender"] == "male"
@ -77,7 +77,4 @@ def test_load_valid_yaml_file(prepare_example_yaml_file):
def test_load_invalid_yaml_file(prepare_invalid_yaml_file):
# yaml syntax error
with pytest.raises(YAMLError):
load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=False)
# ignore error
assert load_yaml_file(file_path=prepare_invalid_yaml_file) == {}
_load_yaml_file(file_path=prepare_invalid_yaml_file)