mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 09:58:04 +08:00
Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@ -203,6 +203,7 @@ ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
|
||||
|
||||
# Reset password token expiry minutes
|
||||
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
|
||||
EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES=5
|
||||
CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5
|
||||
OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5
|
||||
|
||||
|
||||
@ -13,7 +13,6 @@ from services.account_service import AccountService, RegisterService, TenantServ
|
||||
from services.errors.account import (
|
||||
AccountAlreadyInTenantError,
|
||||
AccountLoginError,
|
||||
AccountNotFoundError,
|
||||
AccountPasswordError,
|
||||
AccountRegisterError,
|
||||
CurrentPasswordIncorrectError,
|
||||
@ -161,7 +160,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
password = fake.password(length=12)
|
||||
with pytest.raises(AccountNotFoundError):
|
||||
with pytest.raises(AccountPasswordError):
|
||||
AccountService.authenticate(email, password)
|
||||
|
||||
def test_authenticate_banned_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
|
||||
@ -42,7 +42,7 @@ class TestAdvancedPromptTemplateService:
|
||||
|
||||
# Test data for Baichuan model
|
||||
args = {
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "baichuan-13b-chat",
|
||||
"has_context": "true",
|
||||
@ -77,7 +77,7 @@ class TestAdvancedPromptTemplateService:
|
||||
|
||||
# Test data for common model
|
||||
args = {
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "true",
|
||||
@ -116,7 +116,7 @@ class TestAdvancedPromptTemplateService:
|
||||
|
||||
for model_name in test_cases:
|
||||
args = {
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": model_name,
|
||||
"has_context": "true",
|
||||
@ -144,7 +144,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "true")
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -173,7 +173,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "chat", "true")
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -202,7 +202,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "completion", "true")
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -230,7 +230,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "chat", "true")
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -257,7 +257,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "false")
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "false")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -303,7 +303,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "unsupported_mode", "true")
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "unsupported_mode", "true")
|
||||
|
||||
# Assert: Verify empty dict is returned
|
||||
assert result == {}
|
||||
@ -442,7 +442,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "true")
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -473,7 +473,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "chat", "true")
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -502,7 +502,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "completion", "true")
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -530,7 +530,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "chat", "true")
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -557,7 +557,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "false")
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -603,7 +603,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "unsupported_mode", "true")
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "unsupported_mode", "true")
|
||||
|
||||
# Assert: Verify empty dict is returned
|
||||
assert result == {}
|
||||
@ -621,7 +621,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Test all app modes
|
||||
app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value]
|
||||
app_modes = [AppMode.CHAT, AppMode.COMPLETION]
|
||||
model_modes = ["completion", "chat"]
|
||||
|
||||
for app_mode in app_modes:
|
||||
@ -653,7 +653,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Test all app modes
|
||||
app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value]
|
||||
app_modes = [AppMode.CHAT, AppMode.COMPLETION]
|
||||
model_modes = ["completion", "chat"]
|
||||
|
||||
for app_mode in app_modes:
|
||||
@ -686,10 +686,10 @@ class TestAdvancedPromptTemplateService:
|
||||
# Test edge cases
|
||||
edge_cases = [
|
||||
{"app_mode": "", "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true"},
|
||||
{"app_mode": AppMode.CHAT.value, "model_mode": "", "model_name": "gpt-3.5-turbo", "has_context": "true"},
|
||||
{"app_mode": AppMode.CHAT.value, "model_mode": "completion", "model_name": "", "has_context": "true"},
|
||||
{"app_mode": AppMode.CHAT, "model_mode": "", "model_name": "gpt-3.5-turbo", "has_context": "true"},
|
||||
{"app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "", "has_context": "true"},
|
||||
{
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "",
|
||||
@ -723,7 +723,7 @@ class TestAdvancedPromptTemplateService:
|
||||
|
||||
# Test with context
|
||||
args = {
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "true",
|
||||
@ -757,7 +757,7 @@ class TestAdvancedPromptTemplateService:
|
||||
|
||||
# Test with context
|
||||
args = {
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "baichuan-13b-chat",
|
||||
"has_context": "true",
|
||||
@ -786,25 +786,25 @@ class TestAdvancedPromptTemplateService:
|
||||
# Test different scenarios
|
||||
test_scenarios = [
|
||||
{
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "true",
|
||||
},
|
||||
{
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "chat",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "true",
|
||||
},
|
||||
{
|
||||
"app_mode": AppMode.COMPLETION.value,
|
||||
"app_mode": AppMode.COMPLETION,
|
||||
"model_mode": "completion",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "true",
|
||||
},
|
||||
{
|
||||
"app_mode": AppMode.COMPLETION.value,
|
||||
"app_mode": AppMode.COMPLETION,
|
||||
"model_mode": "chat",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "true",
|
||||
@ -843,25 +843,25 @@ class TestAdvancedPromptTemplateService:
|
||||
# Test different scenarios
|
||||
test_scenarios = [
|
||||
{
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "baichuan-13b-chat",
|
||||
"has_context": "true",
|
||||
},
|
||||
{
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "chat",
|
||||
"model_name": "baichuan-13b-chat",
|
||||
"has_context": "true",
|
||||
},
|
||||
{
|
||||
"app_mode": AppMode.COMPLETION.value,
|
||||
"app_mode": AppMode.COMPLETION,
|
||||
"model_mode": "completion",
|
||||
"model_name": "baichuan-13b-chat",
|
||||
"has_context": "true",
|
||||
},
|
||||
{
|
||||
"app_mode": AppMode.COMPLETION.value,
|
||||
"app_mode": AppMode.COMPLETION,
|
||||
"model_mode": "chat",
|
||||
"model_name": "baichuan-13b-chat",
|
||||
"has_context": "true",
|
||||
|
||||
@ -255,7 +255,7 @@ class TestMetadataService:
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Try to create metadata with built-in field name
|
||||
built_in_field_name = BuiltInField.document_name.value
|
||||
built_in_field_name = BuiltInField.document_name
|
||||
metadata_args = MetadataArgs(type="string", name=built_in_field_name)
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
@ -375,7 +375,7 @@ class TestMetadataService:
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
|
||||
# Try to update with built-in field name
|
||||
built_in_field_name = BuiltInField.document_name.value
|
||||
built_in_field_name = BuiltInField.document_name
|
||||
|
||||
with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."):
|
||||
MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name)
|
||||
@ -540,11 +540,11 @@ class TestMetadataService:
|
||||
field_names = [field["name"] for field in result]
|
||||
field_types = [field["type"] for field in result]
|
||||
|
||||
assert BuiltInField.document_name.value in field_names
|
||||
assert BuiltInField.uploader.value in field_names
|
||||
assert BuiltInField.upload_date.value in field_names
|
||||
assert BuiltInField.last_update_date.value in field_names
|
||||
assert BuiltInField.source.value in field_names
|
||||
assert BuiltInField.document_name in field_names
|
||||
assert BuiltInField.uploader in field_names
|
||||
assert BuiltInField.upload_date in field_names
|
||||
assert BuiltInField.last_update_date in field_names
|
||||
assert BuiltInField.source in field_names
|
||||
|
||||
# Verify field types
|
||||
assert "string" in field_types
|
||||
@ -682,11 +682,11 @@ class TestMetadataService:
|
||||
|
||||
# Set document metadata with built-in fields
|
||||
document.doc_metadata = {
|
||||
BuiltInField.document_name.value: document.name,
|
||||
BuiltInField.uploader.value: "test_uploader",
|
||||
BuiltInField.upload_date.value: 1234567890.0,
|
||||
BuiltInField.last_update_date.value: 1234567890.0,
|
||||
BuiltInField.source.value: "test_source",
|
||||
BuiltInField.document_name: document.name,
|
||||
BuiltInField.uploader: "test_uploader",
|
||||
BuiltInField.upload_date: 1234567890.0,
|
||||
BuiltInField.last_update_date: 1234567890.0,
|
||||
BuiltInField.source: "test_source",
|
||||
}
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
@ -96,7 +96,7 @@ class TestWorkflowService:
|
||||
app.tenant_id = fake.uuid4()
|
||||
app.name = fake.company()
|
||||
app.description = fake.text()
|
||||
app.mode = AppMode.WORKFLOW.value
|
||||
app.mode = AppMode.WORKFLOW
|
||||
app.icon_type = "emoji"
|
||||
app.icon = "🤖"
|
||||
app.icon_background = "#FFEAD5"
|
||||
@ -883,7 +883,7 @@ class TestWorkflowService:
|
||||
|
||||
# Create chat mode app
|
||||
app = self._create_test_app(db_session_with_containers, fake)
|
||||
app.mode = AppMode.CHAT.value
|
||||
app.mode = AppMode.CHAT
|
||||
|
||||
# Create app model config (required for conversion)
|
||||
from models.model import AppModelConfig
|
||||
@ -926,7 +926,7 @@ class TestWorkflowService:
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.mode == AppMode.ADVANCED_CHAT.value # CHAT mode converts to ADVANCED_CHAT, not WORKFLOW
|
||||
assert result.mode == AppMode.ADVANCED_CHAT # CHAT mode converts to ADVANCED_CHAT, not WORKFLOW
|
||||
assert result.name == conversion_args["name"]
|
||||
assert result.icon == conversion_args["icon"]
|
||||
assert result.icon_type == conversion_args["icon_type"]
|
||||
@ -945,7 +945,7 @@ class TestWorkflowService:
|
||||
|
||||
# Create completion mode app
|
||||
app = self._create_test_app(db_session_with_containers, fake)
|
||||
app.mode = AppMode.COMPLETION.value
|
||||
app.mode = AppMode.COMPLETION
|
||||
|
||||
# Create app model config (required for conversion)
|
||||
from models.model import AppModelConfig
|
||||
@ -988,7 +988,7 @@ class TestWorkflowService:
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.mode == AppMode.WORKFLOW.value
|
||||
assert result.mode == AppMode.WORKFLOW
|
||||
assert result.name == conversion_args["name"]
|
||||
assert result.icon == conversion_args["icon"]
|
||||
assert result.icon_type == conversion_args["icon_type"]
|
||||
@ -1007,7 +1007,7 @@ class TestWorkflowService:
|
||||
|
||||
# Create workflow mode app (already in workflow mode)
|
||||
app = self._create_test_app(db_session_with_containers, fake)
|
||||
app.mode = AppMode.WORKFLOW.value
|
||||
app.mode = AppMode.WORKFLOW
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
@ -1030,7 +1030,7 @@ class TestWorkflowService:
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, fake)
|
||||
app.mode = AppMode.ADVANCED_CHAT.value
|
||||
app.mode = AppMode.ADVANCED_CHAT
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
@ -1061,7 +1061,7 @@ class TestWorkflowService:
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, fake)
|
||||
app.mode = AppMode.WORKFLOW.value
|
||||
app.mode = AppMode.WORKFLOW
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
|
||||
@ -0,0 +1,583 @@
|
||||
"""
|
||||
TestContainers-based integration tests for delete_segment_from_index_task.
|
||||
|
||||
This module provides comprehensive integration testing for the delete_segment_from_index_task
|
||||
using TestContainers to ensure realistic database interactions and proper isolation.
|
||||
The task is responsible for removing document segments from the vector index when segments
|
||||
are deleted from the dataset.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from faker import Faker
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from models import Account, Dataset, Document, DocumentSegment, Tenant
|
||||
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestDeleteSegmentFromIndexTask:
|
||||
"""
|
||||
Comprehensive integration tests for delete_segment_from_index_task using testcontainers.
|
||||
|
||||
This test class covers all major functionality of the delete_segment_from_index_task:
|
||||
- Successful segment deletion from index
|
||||
- Dataset not found scenarios
|
||||
- Document not found scenarios
|
||||
- Document status validation (disabled, archived, not completed)
|
||||
- Index processor integration and cleanup
|
||||
- Exception handling and error scenarios
|
||||
- Performance and timing verification
|
||||
|
||||
All tests use the testcontainers infrastructure to ensure proper database isolation
|
||||
and realistic testing environment with actual database interactions.
|
||||
"""
|
||||
|
||||
def _create_test_tenant(self, db_session_with_containers, fake=None):
|
||||
"""
|
||||
Helper method to create a test tenant with realistic data.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
fake: Faker instance for generating test data
|
||||
|
||||
Returns:
|
||||
Tenant: Created test tenant instance
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
tenant = Tenant()
|
||||
tenant.id = fake.uuid4()
|
||||
tenant.name = f"Test Tenant {fake.company()}"
|
||||
tenant.plan = "basic"
|
||||
tenant.status = "active"
|
||||
tenant.created_at = fake.date_time_this_year()
|
||||
tenant.updated_at = tenant.created_at
|
||||
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
return tenant
|
||||
|
||||
def _create_test_account(self, db_session_with_containers, tenant, fake=None):
|
||||
"""
|
||||
Helper method to create a test account with realistic data.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
tenant: Tenant instance for the account
|
||||
fake: Faker instance for generating test data
|
||||
|
||||
Returns:
|
||||
Account: Created test account instance
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
account = Account()
|
||||
account.id = fake.uuid4()
|
||||
account.email = fake.email()
|
||||
account.name = fake.name()
|
||||
account.avatar_url = fake.url()
|
||||
account.tenant_id = tenant.id
|
||||
account.status = "active"
|
||||
account.type = "normal"
|
||||
account.role = "owner"
|
||||
account.interface_language = "en-US"
|
||||
account.created_at = fake.date_time_this_year()
|
||||
account.updated_at = account.created_at
|
||||
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
return account
|
||||
|
||||
def _create_test_dataset(self, db_session_with_containers, tenant, account, fake=None):
|
||||
"""
|
||||
Helper method to create a test dataset with realistic data.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
tenant: Tenant instance for the dataset
|
||||
account: Account instance for the dataset
|
||||
fake: Faker instance for generating test data
|
||||
|
||||
Returns:
|
||||
Dataset: Created test dataset instance
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
dataset = Dataset()
|
||||
dataset.id = fake.uuid4()
|
||||
dataset.tenant_id = tenant.id
|
||||
dataset.name = f"Test Dataset {fake.word()}"
|
||||
dataset.description = fake.text(max_nb_chars=200)
|
||||
dataset.provider = "vendor"
|
||||
dataset.permission = "only_me"
|
||||
dataset.data_source_type = "upload_file"
|
||||
dataset.indexing_technique = "high_quality"
|
||||
dataset.index_struct = '{"type": "paragraph"}'
|
||||
dataset.created_by = account.id
|
||||
dataset.created_at = fake.date_time_this_year()
|
||||
dataset.updated_by = account.id
|
||||
dataset.updated_at = dataset.created_at
|
||||
dataset.embedding_model = "text-embedding-ada-002"
|
||||
dataset.embedding_model_provider = "openai"
|
||||
dataset.built_in_field_enabled = False
|
||||
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
return dataset
|
||||
|
||||
def _create_test_document(self, db_session_with_containers, dataset, account, fake=None, **kwargs):
|
||||
"""
|
||||
Helper method to create a test document with realistic data.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
dataset: Dataset instance for the document
|
||||
account: Account instance for the document
|
||||
fake: Faker instance for generating test data
|
||||
**kwargs: Additional document attributes to override defaults
|
||||
|
||||
Returns:
|
||||
Document: Created test document instance
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
document = Document()
|
||||
document.id = fake.uuid4()
|
||||
document.tenant_id = dataset.tenant_id
|
||||
document.dataset_id = dataset.id
|
||||
document.position = kwargs.get("position", 1)
|
||||
document.data_source_type = kwargs.get("data_source_type", "upload_file")
|
||||
document.data_source_info = kwargs.get("data_source_info", "{}")
|
||||
document.batch = kwargs.get("batch", fake.uuid4())
|
||||
document.name = kwargs.get("name", f"Test Document {fake.word()}")
|
||||
document.created_from = kwargs.get("created_from", "api")
|
||||
document.created_by = account.id
|
||||
document.created_at = fake.date_time_this_year()
|
||||
document.processing_started_at = kwargs.get("processing_started_at", fake.date_time_this_year())
|
||||
document.file_id = kwargs.get("file_id", fake.uuid4())
|
||||
document.word_count = kwargs.get("word_count", fake.random_int(min=100, max=1000))
|
||||
document.parsing_completed_at = kwargs.get("parsing_completed_at", fake.date_time_this_year())
|
||||
document.cleaning_completed_at = kwargs.get("cleaning_completed_at", fake.date_time_this_year())
|
||||
document.splitting_completed_at = kwargs.get("splitting_completed_at", fake.date_time_this_year())
|
||||
document.tokens = kwargs.get("tokens", fake.random_int(min=50, max=500))
|
||||
document.indexing_latency = kwargs.get("indexing_latency", fake.random_number(digits=3))
|
||||
document.completed_at = kwargs.get("completed_at", fake.date_time_this_year())
|
||||
document.is_paused = kwargs.get("is_paused", False)
|
||||
document.indexing_status = kwargs.get("indexing_status", "completed")
|
||||
document.enabled = kwargs.get("enabled", True)
|
||||
document.archived = kwargs.get("archived", False)
|
||||
document.updated_at = fake.date_time_this_year()
|
||||
document.doc_type = kwargs.get("doc_type", "text")
|
||||
document.doc_metadata = kwargs.get("doc_metadata", {})
|
||||
document.doc_form = kwargs.get("doc_form", IndexType.PARAGRAPH_INDEX)
|
||||
document.doc_language = kwargs.get("doc_language", "en")
|
||||
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
return document
|
||||
|
||||
def _create_test_document_segments(self, db_session_with_containers, document, account, count=3, fake=None):
|
||||
"""
|
||||
Helper method to create test document segments with realistic data.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
document: Document instance for the segments
|
||||
account: Account instance for the segments
|
||||
count: Number of segments to create
|
||||
fake: Faker instance for generating test data
|
||||
|
||||
Returns:
|
||||
list[DocumentSegment]: List of created test document segment instances
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
segments = []
|
||||
|
||||
for i in range(count):
|
||||
segment = DocumentSegment()
|
||||
segment.id = fake.uuid4()
|
||||
segment.tenant_id = document.tenant_id
|
||||
segment.dataset_id = document.dataset_id
|
||||
segment.document_id = document.id
|
||||
segment.position = i + 1
|
||||
segment.content = f"Test segment content {i + 1}: {fake.text(max_nb_chars=200)}"
|
||||
segment.answer = f"Test segment answer {i + 1}: {fake.text(max_nb_chars=100)}"
|
||||
segment.word_count = fake.random_int(min=10, max=100)
|
||||
segment.tokens = fake.random_int(min=5, max=50)
|
||||
segment.keywords = [fake.word() for _ in range(3)]
|
||||
segment.index_node_id = f"node_{fake.uuid4()}"
|
||||
segment.index_node_hash = fake.sha256()
|
||||
segment.hit_count = 0
|
||||
segment.enabled = True
|
||||
segment.status = "completed"
|
||||
segment.created_by = account.id
|
||||
segment.created_at = fake.date_time_this_year()
|
||||
segment.updated_by = account.id
|
||||
segment.updated_at = segment.created_at
|
||||
|
||||
db_session_with_containers.add(segment)
|
||||
segments.append(segment)
|
||||
|
||||
db_session_with_containers.commit()
|
||||
return segments
|
||||
|
||||
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
|
||||
def test_delete_segment_from_index_task_success(self, mock_index_processor_factory, db_session_with_containers):
|
||||
"""
|
||||
Test successful segment deletion from index with comprehensive verification.
|
||||
|
||||
This test verifies:
|
||||
- Proper task execution with valid dataset and document
|
||||
- Index processor factory initialization with correct document form
|
||||
- Index processor clean method called with correct parameters
|
||||
- Database session properly closed after execution
|
||||
- Task completes without exceptions
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data
|
||||
tenant = self._create_test_tenant(db_session_with_containers, fake)
|
||||
account = self._create_test_account(db_session_with_containers, tenant, fake)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
|
||||
document = self._create_test_document(db_session_with_containers, dataset, account, fake)
|
||||
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
|
||||
|
||||
# Extract index node IDs for the task
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# Mock the index processor
|
||||
mock_processor = MagicMock()
|
||||
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
|
||||
|
||||
# Execute the task
|
||||
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
|
||||
|
||||
# Verify the task completed successfully
|
||||
assert result is None # Task should return None on success
|
||||
|
||||
# Verify index processor factory was called with correct document form
|
||||
mock_index_processor_factory.assert_called_once_with(document.doc_form)
|
||||
|
||||
# Verify index processor clean method was called with correct parameters
|
||||
# Note: We can't directly compare Dataset objects as they are different instances
|
||||
# from database queries, so we verify the call was made and check the parameters
|
||||
assert mock_processor.clean.call_count == 1
|
||||
call_args = mock_processor.clean.call_args
|
||||
assert call_args[0][0].id == dataset.id # Verify dataset ID matches
|
||||
assert call_args[0][1] == index_node_ids # Verify index node IDs match
|
||||
assert call_args[1]["with_keywords"] is True
|
||||
assert call_args[1]["delete_child_chunks"] is True
|
||||
|
||||
def test_delete_segment_from_index_task_dataset_not_found(self, db_session_with_containers):
|
||||
"""
|
||||
Test task behavior when dataset is not found.
|
||||
|
||||
This test verifies:
|
||||
- Task handles missing dataset gracefully
|
||||
- No index processor operations are attempted
|
||||
- Task returns early without exceptions
|
||||
- Database session is properly closed
|
||||
"""
|
||||
fake = Faker()
|
||||
non_existent_dataset_id = fake.uuid4()
|
||||
non_existent_document_id = fake.uuid4()
|
||||
index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)]
|
||||
|
||||
# Execute the task with non-existent dataset
|
||||
result = delete_segment_from_index_task(index_node_ids, non_existent_dataset_id, non_existent_document_id)
|
||||
|
||||
# Verify the task completed without exceptions
|
||||
assert result is None # Task should return None when dataset not found
|
||||
|
||||
def test_delete_segment_from_index_task_document_not_found(self, db_session_with_containers):
|
||||
"""
|
||||
Test task behavior when document is not found.
|
||||
|
||||
This test verifies:
|
||||
- Task handles missing document gracefully
|
||||
- No index processor operations are attempted
|
||||
- Task returns early without exceptions
|
||||
- Database session is properly closed
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data
|
||||
tenant = self._create_test_tenant(db_session_with_containers, fake)
|
||||
account = self._create_test_account(db_session_with_containers, tenant, fake)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
|
||||
|
||||
non_existent_document_id = fake.uuid4()
|
||||
index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)]
|
||||
|
||||
# Execute the task with non-existent document
|
||||
result = delete_segment_from_index_task(index_node_ids, dataset.id, non_existent_document_id)
|
||||
|
||||
# Verify the task completed without exceptions
|
||||
assert result is None # Task should return None when document not found
|
||||
|
||||
def test_delete_segment_from_index_task_document_disabled(self, db_session_with_containers):
|
||||
"""
|
||||
Test task behavior when document is disabled.
|
||||
|
||||
This test verifies:
|
||||
- Task handles disabled document gracefully
|
||||
- No index processor operations are attempted
|
||||
- Task returns early without exceptions
|
||||
- Database session is properly closed
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data with disabled document
|
||||
tenant = self._create_test_tenant(db_session_with_containers, fake)
|
||||
account = self._create_test_account(db_session_with_containers, tenant, fake)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
|
||||
document = self._create_test_document(db_session_with_containers, dataset, account, fake, enabled=False)
|
||||
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
|
||||
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# Execute the task with disabled document
|
||||
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
|
||||
|
||||
# Verify the task completed without exceptions
|
||||
assert result is None # Task should return None when document is disabled
|
||||
|
||||
def test_delete_segment_from_index_task_document_archived(self, db_session_with_containers):
|
||||
"""
|
||||
Test task behavior when document is archived.
|
||||
|
||||
This test verifies:
|
||||
- Task handles archived document gracefully
|
||||
- No index processor operations are attempted
|
||||
- Task returns early without exceptions
|
||||
- Database session is properly closed
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data with archived document
|
||||
tenant = self._create_test_tenant(db_session_with_containers, fake)
|
||||
account = self._create_test_account(db_session_with_containers, tenant, fake)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
|
||||
document = self._create_test_document(db_session_with_containers, dataset, account, fake, archived=True)
|
||||
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
|
||||
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# Execute the task with archived document
|
||||
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
|
||||
|
||||
# Verify the task completed without exceptions
|
||||
assert result is None # Task should return None when document is archived
|
||||
|
||||
def test_delete_segment_from_index_task_document_not_completed(self, db_session_with_containers):
|
||||
"""
|
||||
Test task behavior when document indexing is not completed.
|
||||
|
||||
This test verifies:
|
||||
- Task handles incomplete indexing status gracefully
|
||||
- No index processor operations are attempted
|
||||
- Task returns early without exceptions
|
||||
- Database session is properly closed
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data with incomplete indexing
|
||||
tenant = self._create_test_tenant(db_session_with_containers, fake)
|
||||
account = self._create_test_account(db_session_with_containers, tenant, fake)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
|
||||
document = self._create_test_document(
|
||||
db_session_with_containers, dataset, account, fake, indexing_status="indexing"
|
||||
)
|
||||
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
|
||||
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# Execute the task with incomplete indexing
|
||||
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
|
||||
|
||||
# Verify the task completed without exceptions
|
||||
assert result is None # Task should return None when indexing is not completed
|
||||
|
||||
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
|
||||
def test_delete_segment_from_index_task_index_processor_clean(
|
||||
self, mock_index_processor_factory, db_session_with_containers
|
||||
):
|
||||
"""
|
||||
Test index processor clean method integration with different document forms.
|
||||
|
||||
This test verifies:
|
||||
- Index processor factory creates correct processor for different document forms
|
||||
- Clean method is called with proper parameters for each document form
|
||||
- Task handles different index types correctly
|
||||
- Database session is properly managed
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Test different document forms
|
||||
document_forms = [IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX, IndexType.PARENT_CHILD_INDEX]
|
||||
|
||||
for doc_form in document_forms:
|
||||
# Create test data for each document form
|
||||
tenant = self._create_test_tenant(db_session_with_containers, fake)
|
||||
account = self._create_test_account(db_session_with_containers, tenant, fake)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
|
||||
document = self._create_test_document(db_session_with_containers, dataset, account, fake, doc_form=doc_form)
|
||||
segments = self._create_test_document_segments(db_session_with_containers, document, account, 2, fake)
|
||||
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# Mock the index processor
|
||||
mock_processor = MagicMock()
|
||||
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
|
||||
|
||||
# Execute the task
|
||||
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
|
||||
|
||||
# Verify the task completed successfully
|
||||
assert result is None
|
||||
|
||||
# Verify index processor factory was called with correct document form
|
||||
mock_index_processor_factory.assert_called_with(doc_form)
|
||||
|
||||
# Verify index processor clean method was called with correct parameters
|
||||
assert mock_processor.clean.call_count == 1
|
||||
call_args = mock_processor.clean.call_args
|
||||
assert call_args[0][0].id == dataset.id # Verify dataset ID matches
|
||||
assert call_args[0][1] == index_node_ids # Verify index node IDs match
|
||||
assert call_args[1]["with_keywords"] is True
|
||||
assert call_args[1]["delete_child_chunks"] is True
|
||||
|
||||
# Reset mocks for next iteration
|
||||
mock_index_processor_factory.reset_mock()
|
||||
mock_processor.reset_mock()
|
||||
|
||||
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
|
||||
def test_delete_segment_from_index_task_exception_handling(
|
||||
self, mock_index_processor_factory, db_session_with_containers
|
||||
):
|
||||
"""
|
||||
Test exception handling in the task.
|
||||
|
||||
This test verifies:
|
||||
- Task handles index processor exceptions gracefully
|
||||
- Database session is properly closed even when exceptions occur
|
||||
- Task logs exceptions appropriately
|
||||
- No unhandled exceptions are raised
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data
|
||||
tenant = self._create_test_tenant(db_session_with_containers, fake)
|
||||
account = self._create_test_account(db_session_with_containers, tenant, fake)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
|
||||
document = self._create_test_document(db_session_with_containers, dataset, account, fake)
|
||||
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
|
||||
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# Mock the index processor to raise an exception
|
||||
mock_processor = MagicMock()
|
||||
mock_processor.clean.side_effect = Exception("Index processor error")
|
||||
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
|
||||
|
||||
# Execute the task - should not raise exception
|
||||
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
|
||||
|
||||
# Verify the task completed without raising exceptions
|
||||
assert result is None # Task should return None even when exceptions occur
|
||||
|
||||
# Verify index processor clean method was called
|
||||
assert mock_processor.clean.call_count == 1
|
||||
call_args = mock_processor.clean.call_args
|
||||
assert call_args[0][0].id == dataset.id # Verify dataset ID matches
|
||||
assert call_args[0][1] == index_node_ids # Verify index node IDs match
|
||||
assert call_args[1]["with_keywords"] is True
|
||||
assert call_args[1]["delete_child_chunks"] is True
|
||||
|
||||
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
|
||||
def test_delete_segment_from_index_task_empty_index_node_ids(
|
||||
self, mock_index_processor_factory, db_session_with_containers
|
||||
):
|
||||
"""
|
||||
Test task behavior with empty index node IDs list.
|
||||
|
||||
This test verifies:
|
||||
- Task handles empty index node IDs gracefully
|
||||
- Index processor clean method is called with empty list
|
||||
- Task completes successfully
|
||||
- Database session is properly managed
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data
|
||||
tenant = self._create_test_tenant(db_session_with_containers, fake)
|
||||
account = self._create_test_account(db_session_with_containers, tenant, fake)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
|
||||
document = self._create_test_document(db_session_with_containers, dataset, account, fake)
|
||||
|
||||
# Use empty index node IDs
|
||||
index_node_ids = []
|
||||
|
||||
# Mock the index processor
|
||||
mock_processor = MagicMock()
|
||||
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
|
||||
|
||||
# Execute the task
|
||||
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
|
||||
|
||||
# Verify the task completed successfully
|
||||
assert result is None
|
||||
|
||||
# Verify index processor clean method was called with empty list
|
||||
assert mock_processor.clean.call_count == 1
|
||||
call_args = mock_processor.clean.call_args
|
||||
assert call_args[0][0].id == dataset.id # Verify dataset ID matches
|
||||
assert call_args[0][1] == index_node_ids # Verify index node IDs match (empty list)
|
||||
assert call_args[1]["with_keywords"] is True
|
||||
assert call_args[1]["delete_child_chunks"] is True
|
||||
|
||||
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
|
||||
def test_delete_segment_from_index_task_large_index_node_ids(
|
||||
self, mock_index_processor_factory, db_session_with_containers
|
||||
):
|
||||
"""
|
||||
Test task behavior with large number of index node IDs.
|
||||
|
||||
This test verifies:
|
||||
- Task handles large lists of index node IDs efficiently
|
||||
- Index processor clean method is called with all node IDs
|
||||
- Task completes successfully with large datasets
|
||||
- Database session is properly managed
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data
|
||||
tenant = self._create_test_tenant(db_session_with_containers, fake)
|
||||
account = self._create_test_account(db_session_with_containers, tenant, fake)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
|
||||
document = self._create_test_document(db_session_with_containers, dataset, account, fake)
|
||||
|
||||
# Create large number of segments
|
||||
segments = self._create_test_document_segments(db_session_with_containers, document, account, 50, fake)
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# Mock the index processor
|
||||
mock_processor = MagicMock()
|
||||
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
|
||||
|
||||
# Execute the task
|
||||
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
|
||||
|
||||
# Verify the task completed successfully
|
||||
assert result is None
|
||||
|
||||
# Verify index processor clean method was called with all node IDs
|
||||
assert mock_processor.clean.call_count == 1
|
||||
call_args = mock_processor.clean.call_args
|
||||
assert call_args[0][0].id == dataset.id # Verify dataset ID matches
|
||||
assert call_args[0][1] == index_node_ids # Verify index node IDs match
|
||||
assert call_args[1]["with_keywords"] is True
|
||||
assert call_args[1]["delete_child_chunks"] is True
|
||||
|
||||
# Verify all node IDs were passed
|
||||
assert len(call_args[0][1]) == 50
|
||||
@ -9,7 +9,6 @@ from flask_restx import Api
|
||||
import services.errors.account
|
||||
from controllers.console.auth.error import AuthenticationFailedError
|
||||
from controllers.console.auth.login import LoginApi
|
||||
from controllers.console.error import AccountNotFound
|
||||
|
||||
|
||||
class TestAuthenticationSecurity:
|
||||
@ -27,31 +26,33 @@ class TestAuthenticationSecurity:
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.send_reset_password_email")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
def test_login_invalid_email_with_registration_allowed(
|
||||
self, mock_get_invitation, mock_send_email, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
|
||||
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
|
||||
):
|
||||
"""Test that invalid email sends reset password email when registration is allowed."""
|
||||
"""Test that invalid email raises AuthenticationFailedError when account not found."""
|
||||
# Arrange
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountNotFoundError("Account not found")
|
||||
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
mock_features.return_value.is_allow_register = True
|
||||
mock_send_email.return_value = "token123"
|
||||
|
||||
# Act
|
||||
with self.app.test_request_context(
|
||||
"/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
result = login_api.post()
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "fail", "data": "token123", "code": "account_not_found"}
|
||||
mock_send_email.assert_called_once_with(email="nonexistent@example.com", language="en-US")
|
||||
# Assert
|
||||
with pytest.raises(AuthenticationFailedError) as exc_info:
|
||||
login_api.post()
|
||||
|
||||
assert exc_info.value.error_code == "authentication_failed"
|
||||
assert exc_info.value.description == "Invalid email or password."
|
||||
mock_add_rate_limit.assert_called_once_with("nonexistent@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@ -87,16 +88,17 @@ class TestAuthenticationSecurity:
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
def test_login_invalid_email_with_registration_disabled(
|
||||
self, mock_get_invitation, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
|
||||
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
|
||||
):
|
||||
"""Test that invalid email raises AccountNotFound when registration is disabled."""
|
||||
"""Test that invalid email raises AuthenticationFailedError when account not found."""
|
||||
# Arrange
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountNotFoundError("Account not found")
|
||||
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
mock_features.return_value.is_allow_register = False
|
||||
|
||||
@ -107,10 +109,12 @@ class TestAuthenticationSecurity:
|
||||
login_api = LoginApi()
|
||||
|
||||
# Assert
|
||||
with pytest.raises(AccountNotFound) as exc_info:
|
||||
with pytest.raises(AuthenticationFailedError) as exc_info:
|
||||
login_api.post()
|
||||
|
||||
assert exc_info.value.error_code == "account_not_found"
|
||||
assert exc_info.value.error_code == "authentication_failed"
|
||||
assert exc_info.value.description == "Invalid email or password."
|
||||
mock_add_rate_limit.assert_called_once_with("nonexistent@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
|
||||
@ -12,7 +12,7 @@ from controllers.console.auth.oauth import (
|
||||
)
|
||||
from libs.oauth import OAuthUserInfo
|
||||
from models.account import AccountStatus
|
||||
from services.errors.account import AccountNotFoundError
|
||||
from services.errors.account import AccountRegisterError
|
||||
|
||||
|
||||
class TestGetOAuthProviders:
|
||||
@ -451,7 +451,7 @@ class TestAccountGeneration:
|
||||
|
||||
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
|
||||
if not allow_register and not existing_account:
|
||||
with pytest.raises(AccountNotFoundError):
|
||||
with pytest.raises(AccountRegisterError):
|
||||
_generate_account("github", user_info)
|
||||
else:
|
||||
result = _generate_account("github", user_info)
|
||||
|
||||
@ -29,7 +29,7 @@ class TestHandleMCPRequest:
|
||||
"""Setup test fixtures"""
|
||||
self.app = Mock(spec=App)
|
||||
self.app.name = "test_app"
|
||||
self.app.mode = AppMode.CHAT.value
|
||||
self.app.mode = AppMode.CHAT
|
||||
|
||||
self.mcp_server = Mock(spec=AppMCPServer)
|
||||
self.mcp_server.description = "Test server"
|
||||
@ -196,7 +196,7 @@ class TestIndividualHandlers:
|
||||
def test_handle_list_tools(self):
|
||||
"""Test list tools handler"""
|
||||
app_name = "test_app"
|
||||
app_mode = AppMode.CHAT.value
|
||||
app_mode = AppMode.CHAT
|
||||
description = "Test server"
|
||||
parameters_dict: dict[str, str] = {}
|
||||
user_input_form: list[VariableEntity] = []
|
||||
@ -212,7 +212,7 @@ class TestIndividualHandlers:
|
||||
def test_handle_call_tool(self, mock_app_generate):
|
||||
"""Test call tool handler"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.CHAT.value
|
||||
app.mode = AppMode.CHAT
|
||||
|
||||
# Create mock request
|
||||
mock_request = Mock()
|
||||
@ -252,7 +252,7 @@ class TestUtilityFunctions:
|
||||
|
||||
def test_build_parameter_schema_chat_mode(self):
|
||||
"""Test building parameter schema for chat mode"""
|
||||
app_mode = AppMode.CHAT.value
|
||||
app_mode = AppMode.CHAT
|
||||
parameters_dict: dict[str, str] = {"name": "Enter your name"}
|
||||
|
||||
user_input_form = [
|
||||
@ -275,7 +275,7 @@ class TestUtilityFunctions:
|
||||
|
||||
def test_build_parameter_schema_workflow_mode(self):
|
||||
"""Test building parameter schema for workflow mode"""
|
||||
app_mode = AppMode.WORKFLOW.value
|
||||
app_mode = AppMode.WORKFLOW
|
||||
parameters_dict: dict[str, str] = {"input_text": "Enter text"}
|
||||
|
||||
user_input_form = [
|
||||
@ -298,7 +298,7 @@ class TestUtilityFunctions:
|
||||
def test_prepare_tool_arguments_chat_mode(self):
|
||||
"""Test preparing tool arguments for chat mode"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.CHAT.value
|
||||
app.mode = AppMode.CHAT
|
||||
|
||||
arguments = {"query": "test question", "name": "John"}
|
||||
|
||||
@ -312,7 +312,7 @@ class TestUtilityFunctions:
|
||||
def test_prepare_tool_arguments_workflow_mode(self):
|
||||
"""Test preparing tool arguments for workflow mode"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.WORKFLOW.value
|
||||
app.mode = AppMode.WORKFLOW
|
||||
|
||||
arguments = {"input_text": "test input"}
|
||||
|
||||
@ -324,7 +324,7 @@ class TestUtilityFunctions:
|
||||
def test_prepare_tool_arguments_completion_mode(self):
|
||||
"""Test preparing tool arguments for completion mode"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.COMPLETION.value
|
||||
app.mode = AppMode.COMPLETION
|
||||
|
||||
arguments = {"name": "John"}
|
||||
|
||||
@ -336,7 +336,7 @@ class TestUtilityFunctions:
|
||||
def test_extract_answer_from_mapping_response_chat(self):
|
||||
"""Test extracting answer from mapping response for chat mode"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.CHAT.value
|
||||
app.mode = AppMode.CHAT
|
||||
|
||||
response = {"answer": "test answer", "other": "data"}
|
||||
|
||||
@ -347,7 +347,7 @@ class TestUtilityFunctions:
|
||||
def test_extract_answer_from_mapping_response_workflow(self):
|
||||
"""Test extracting answer from mapping response for workflow mode"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.WORKFLOW.value
|
||||
app.mode = AppMode.WORKFLOW
|
||||
|
||||
response = {"data": {"outputs": {"result": "test result"}}}
|
||||
|
||||
|
||||
@ -20,7 +20,6 @@ def test_firecrawl_web_extractor_crawl_mode(mocker):
|
||||
}
|
||||
mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl))
|
||||
job_id = firecrawl_app.crawl_url(url, params)
|
||||
print(f"job_id: {job_id}")
|
||||
|
||||
assert job_id is not None
|
||||
assert isinstance(job_id, str)
|
||||
|
||||
@ -129,7 +129,6 @@ class TestSegmentDumpAndLoad:
|
||||
"""Test basic segment serialization compatibility"""
|
||||
model = _Segments(segments=[IntegerSegment(value=1), StringSegment(value="a")])
|
||||
json = model.model_dump_json()
|
||||
print("Json: ", json)
|
||||
loaded = _Segments.model_validate_json(json)
|
||||
assert loaded == model
|
||||
|
||||
@ -137,7 +136,6 @@ class TestSegmentDumpAndLoad:
|
||||
"""Test number segment serialization compatibility"""
|
||||
model = _Segments(segments=[IntegerSegment(value=1), FloatSegment(value=1.0)])
|
||||
json = model.model_dump_json()
|
||||
print("Json: ", json)
|
||||
loaded = _Segments.model_validate_json(json)
|
||||
assert loaded == model
|
||||
|
||||
@ -145,7 +143,6 @@ class TestSegmentDumpAndLoad:
|
||||
"""Test variable serialization compatibility"""
|
||||
model = _Variables(variables=[IntegerVariable(value=1, name="int"), StringVariable(value="a", name="str")])
|
||||
json = model.model_dump_json()
|
||||
print("Json: ", json)
|
||||
restored = _Variables.model_validate_json(json)
|
||||
assert restored == model
|
||||
|
||||
|
||||
@ -20,7 +20,7 @@ from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
from core.tools.utils.yaml_utils import _load_yaml_file
|
||||
from core.variables import (
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
@ -713,4 +713,4 @@ def _load_fixture(fixture_path: Path, fixture_name: str) -> dict[str, Any]:
|
||||
if not fixture_path.exists():
|
||||
raise FileNotFoundError(f"Fixture file not found: {fixture_path}")
|
||||
|
||||
return load_yaml_file(str(fixture_path), ignore_error=False)
|
||||
return _load_yaml_file(file_path=str(fixture_path))
|
||||
|
||||
@ -142,15 +142,11 @@ def test_remove_first_from_array():
|
||||
node.init_node_data(node_config["data"])
|
||||
|
||||
# Skip the mock assertion since we're in a test environment
|
||||
# Print the variable before running
|
||||
print(f"Before: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}")
|
||||
|
||||
# Run the node
|
||||
result = list(node.run())
|
||||
|
||||
# Print the variable after running and the result
|
||||
print(f"After: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}")
|
||||
print(f"Result: {result}")
|
||||
# Completed run
|
||||
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
assert got is not None
|
||||
|
||||
456
api/tests/unit_tests/core/workflow/test_workflow_entry.py
Normal file
456
api/tests/unit_tests/core/workflow/test_workflow_entry.py
Normal file
@ -0,0 +1,456 @@
|
||||
import pytest
|
||||
|
||||
from core.file.enums import FileType
|
||||
from core.file.models import File, FileTransferMethod
|
||||
from core.variables.variables import StringVariable
|
||||
from core.workflow.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
ENVIRONMENT_VARIABLE_NODE_ID,
|
||||
)
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
|
||||
|
||||
class TestWorkflowEntry:
|
||||
"""Test WorkflowEntry class methods."""
|
||||
|
||||
def test_mapping_user_inputs_to_variable_pool_with_system_variables(self):
|
||||
"""Test mapping system variables from user inputs to variable pool."""
|
||||
# Initialize variable pool with system variables
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="test_user_id",
|
||||
app_id="test_app_id",
|
||||
workflow_id="test_workflow_id",
|
||||
),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
# Define variable mapping - sys variables mapped to other nodes
|
||||
variable_mapping = {
|
||||
"node1.input1": ["node1", "input1"], # Regular mapping
|
||||
"node2.query": ["node2", "query"], # Regular mapping
|
||||
"sys.user_id": ["output_node", "user"], # System variable mapping
|
||||
}
|
||||
|
||||
# User inputs including sys variables
|
||||
user_inputs = {
|
||||
"node1.input1": "new_user_id",
|
||||
"node2.query": "test query",
|
||||
"sys.user_id": "system_user",
|
||||
}
|
||||
|
||||
# Execute mapping
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
# Verify variables were added to pool
|
||||
# Note: variable_pool.get returns Variable objects, not raw values
|
||||
node1_var = variable_pool.get(["node1", "input1"])
|
||||
assert node1_var is not None
|
||||
assert node1_var.value == "new_user_id"
|
||||
|
||||
node2_var = variable_pool.get(["node2", "query"])
|
||||
assert node2_var is not None
|
||||
assert node2_var.value == "test query"
|
||||
|
||||
# System variable gets mapped to output node
|
||||
output_var = variable_pool.get(["output_node", "user"])
|
||||
assert output_var is not None
|
||||
assert output_var.value == "system_user"
|
||||
|
||||
def test_mapping_user_inputs_to_variable_pool_with_env_variables(self):
|
||||
"""Test mapping environment variables from user inputs to variable pool."""
|
||||
# Initialize variable pool with environment variables
|
||||
env_var = StringVariable(name="API_KEY", value="existing_key")
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
environment_variables=[env_var],
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
# Add env variable to pool (simulating initialization)
|
||||
variable_pool.add([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"], env_var)
|
||||
|
||||
# Define variable mapping - env variables should not be overridden
|
||||
variable_mapping = {
|
||||
"node1.api_key": [ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"],
|
||||
"node2.new_env": [ENVIRONMENT_VARIABLE_NODE_ID, "NEW_ENV"],
|
||||
}
|
||||
|
||||
# User inputs
|
||||
user_inputs = {
|
||||
"node1.api_key": "user_provided_key", # This should not override existing env var
|
||||
"node2.new_env": "new_env_value",
|
||||
}
|
||||
|
||||
# Execute mapping
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
# Verify env variable was not overridden
|
||||
env_value = variable_pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"])
|
||||
assert env_value is not None
|
||||
assert env_value.value == "existing_key" # Should remain unchanged
|
||||
|
||||
# New env variables from user input should not be added
|
||||
assert variable_pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "NEW_ENV"]) is None
|
||||
|
||||
def test_mapping_user_inputs_to_variable_pool_with_conversation_variables(self):
|
||||
"""Test mapping conversation variables from user inputs to variable pool."""
|
||||
# Initialize variable pool with conversation variables
|
||||
conv_var = StringVariable(name="last_message", value="Hello")
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
conversation_variables=[conv_var],
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
# Add conversation variable to pool
|
||||
variable_pool.add([CONVERSATION_VARIABLE_NODE_ID, "last_message"], conv_var)
|
||||
|
||||
# Define variable mapping
|
||||
variable_mapping = {
|
||||
"node1.message": ["node1", "message"], # Map to regular node
|
||||
"conversation.context": ["chat_node", "context"], # Conversation var to regular node
|
||||
}
|
||||
|
||||
# User inputs
|
||||
user_inputs = {
|
||||
"node1.message": "Updated message",
|
||||
"conversation.context": "New context",
|
||||
}
|
||||
|
||||
# Execute mapping
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
# Verify variables were added to their target nodes
|
||||
node1_var = variable_pool.get(["node1", "message"])
|
||||
assert node1_var is not None
|
||||
assert node1_var.value == "Updated message"
|
||||
|
||||
chat_var = variable_pool.get(["chat_node", "context"])
|
||||
assert chat_var is not None
|
||||
assert chat_var.value == "New context"
|
||||
|
||||
def test_mapping_user_inputs_to_variable_pool_with_regular_variables(self):
|
||||
"""Test mapping regular node variables from user inputs to variable pool."""
|
||||
# Initialize empty variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
# Define variable mapping for regular nodes
|
||||
variable_mapping = {
|
||||
"input_node.text": ["input_node", "text"],
|
||||
"llm_node.prompt": ["llm_node", "prompt"],
|
||||
"code_node.input": ["code_node", "input"],
|
||||
}
|
||||
|
||||
# User inputs
|
||||
user_inputs = {
|
||||
"input_node.text": "User input text",
|
||||
"llm_node.prompt": "Generate a summary",
|
||||
"code_node.input": {"key": "value"},
|
||||
}
|
||||
|
||||
# Execute mapping
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
# Verify regular variables were added
|
||||
text_var = variable_pool.get(["input_node", "text"])
|
||||
assert text_var is not None
|
||||
assert text_var.value == "User input text"
|
||||
|
||||
prompt_var = variable_pool.get(["llm_node", "prompt"])
|
||||
assert prompt_var is not None
|
||||
assert prompt_var.value == "Generate a summary"
|
||||
|
||||
input_var = variable_pool.get(["code_node", "input"])
|
||||
assert input_var is not None
|
||||
assert input_var.value == {"key": "value"}
|
||||
|
||||
def test_mapping_user_inputs_with_file_handling(self):
|
||||
"""Test mapping file inputs from user inputs to variable pool."""
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
# Define variable mapping
|
||||
variable_mapping = {
|
||||
"file_node.file": ["file_node", "file"],
|
||||
"file_node.files": ["file_node", "files"],
|
||||
}
|
||||
|
||||
# User inputs with file data - using remote_url which doesn't require upload_file_id
|
||||
user_inputs = {
|
||||
"file_node.file": {
|
||||
"type": "document",
|
||||
"transfer_method": "remote_url",
|
||||
"url": "http://example.com/test.pdf",
|
||||
},
|
||||
"file_node.files": [
|
||||
{
|
||||
"type": "image",
|
||||
"transfer_method": "remote_url",
|
||||
"url": "http://example.com/image1.jpg",
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"transfer_method": "remote_url",
|
||||
"url": "http://example.com/image2.jpg",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
# Execute mapping
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
# Verify file was converted and added
|
||||
file_var = variable_pool.get(["file_node", "file"])
|
||||
assert file_var is not None
|
||||
assert file_var.value.type == FileType.DOCUMENT
|
||||
assert file_var.value.transfer_method == FileTransferMethod.REMOTE_URL
|
||||
|
||||
# Verify file list was converted and added
|
||||
files_var = variable_pool.get(["file_node", "files"])
|
||||
assert files_var is not None
|
||||
assert isinstance(files_var.value, list)
|
||||
assert len(files_var.value) == 2
|
||||
assert all(isinstance(f, File) for f in files_var.value)
|
||||
assert files_var.value[0].type == FileType.IMAGE
|
||||
assert files_var.value[1].type == FileType.IMAGE
|
||||
assert files_var.value[0].type == FileType.IMAGE
|
||||
assert files_var.value[1].type == FileType.IMAGE
|
||||
|
||||
def test_mapping_user_inputs_missing_variable_error(self):
|
||||
"""Test that mapping raises error when required variable is missing."""
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
# Define variable mapping
|
||||
variable_mapping = {
|
||||
"node1.required_input": ["node1", "required_input"],
|
||||
}
|
||||
|
||||
# User inputs without required variable
|
||||
user_inputs = {
|
||||
"node1.other_input": "some value",
|
||||
}
|
||||
|
||||
# Should raise ValueError for missing variable
|
||||
with pytest.raises(ValueError, match="Variable key node1.required_input not found in user inputs"):
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
def test_mapping_user_inputs_with_alternative_key_format(self):
|
||||
"""Test mapping with alternative key format (without node prefix)."""
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
# Define variable mapping
|
||||
variable_mapping = {
|
||||
"node1.input": ["node1", "input"],
|
||||
}
|
||||
|
||||
# User inputs with alternative key format
|
||||
user_inputs = {
|
||||
"input": "value without node prefix", # Alternative format without node prefix
|
||||
}
|
||||
|
||||
# Execute mapping
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
# Verify variable was added using alternative key
|
||||
input_var = variable_pool.get(["node1", "input"])
|
||||
assert input_var is not None
|
||||
assert input_var.value == "value without node prefix"
|
||||
|
||||
def test_mapping_user_inputs_with_complex_selectors(self):
|
||||
"""Test mapping with complex node variable keys."""
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
# Define variable mapping - selectors can only have 2 elements
|
||||
variable_mapping = {
|
||||
"node1.data.field1": ["node1", "data_field1"], # Complex key mapped to simple selector
|
||||
"node2.config.settings.timeout": ["node2", "timeout"], # Complex key mapped to simple selector
|
||||
}
|
||||
|
||||
# User inputs
|
||||
user_inputs = {
|
||||
"node1.data.field1": "nested value",
|
||||
"node2.config.settings.timeout": 30,
|
||||
}
|
||||
|
||||
# Execute mapping
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
# Verify variables were added with simple selectors
|
||||
data_var = variable_pool.get(["node1", "data_field1"])
|
||||
assert data_var is not None
|
||||
assert data_var.value == "nested value"
|
||||
|
||||
timeout_var = variable_pool.get(["node2", "timeout"])
|
||||
assert timeout_var is not None
|
||||
assert timeout_var.value == 30
|
||||
|
||||
def test_mapping_user_inputs_invalid_node_variable(self):
|
||||
"""Test that mapping handles invalid node variable format."""
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
# Define variable mapping with single element node variable (at least one dot is required)
|
||||
variable_mapping = {
|
||||
"singleelement": ["node1", "input"], # No dot separator
|
||||
}
|
||||
|
||||
user_inputs = {"singleelement": "some value"} # Must use exact key
|
||||
|
||||
# Should NOT raise error - function accepts it and uses direct key
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
# Verify it was added
|
||||
var = variable_pool.get(["node1", "input"])
|
||||
assert var is not None
|
||||
assert var.value == "some value"
|
||||
|
||||
def test_mapping_all_variable_types_together(self):
|
||||
"""Test mapping all four types of variables in one operation."""
|
||||
# Initialize variable pool with some existing variables
|
||||
env_var = StringVariable(name="API_KEY", value="existing_key")
|
||||
conv_var = StringVariable(name="session_id", value="session123")
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="test_user",
|
||||
app_id="test_app",
|
||||
query="initial query",
|
||||
),
|
||||
environment_variables=[env_var],
|
||||
conversation_variables=[conv_var],
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
# Add existing variables to pool
|
||||
variable_pool.add([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"], env_var)
|
||||
variable_pool.add([CONVERSATION_VARIABLE_NODE_ID, "session_id"], conv_var)
|
||||
|
||||
# Define comprehensive variable mapping
|
||||
variable_mapping = {
|
||||
# System variables mapped to regular nodes
|
||||
"sys.user_id": ["start", "user"],
|
||||
"sys.app_id": ["start", "app"],
|
||||
# Environment variables (won't be overridden)
|
||||
"env.API_KEY": ["config", "api_key"],
|
||||
# Conversation variables mapped to regular nodes
|
||||
"conversation.session_id": ["chat", "session"],
|
||||
# Regular variables
|
||||
"input.text": ["input", "text"],
|
||||
"process.data": ["process", "data"],
|
||||
}
|
||||
|
||||
# User inputs
|
||||
user_inputs = {
|
||||
"sys.user_id": "new_user",
|
||||
"sys.app_id": "new_app",
|
||||
"env.API_KEY": "attempted_override", # Should not override env var
|
||||
"conversation.session_id": "new_session",
|
||||
"input.text": "user input text",
|
||||
"process.data": {"value": 123, "status": "active"},
|
||||
}
|
||||
|
||||
# Execute mapping
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
# Verify system variables were added to their target nodes
|
||||
start_user = variable_pool.get(["start", "user"])
|
||||
assert start_user is not None
|
||||
assert start_user.value == "new_user"
|
||||
|
||||
start_app = variable_pool.get(["start", "app"])
|
||||
assert start_app is not None
|
||||
assert start_app.value == "new_app"
|
||||
|
||||
# Verify env variable was not overridden (still has original value)
|
||||
env_value = variable_pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"])
|
||||
assert env_value is not None
|
||||
assert env_value.value == "existing_key"
|
||||
|
||||
# Environment variables get mapped to other nodes even when they exist in env pool
|
||||
# But the original env value remains unchanged
|
||||
config_api_key = variable_pool.get(["config", "api_key"])
|
||||
assert config_api_key is not None
|
||||
assert config_api_key.value == "attempted_override"
|
||||
|
||||
# Verify conversation variable was mapped to target node
|
||||
chat_session = variable_pool.get(["chat", "session"])
|
||||
assert chat_session is not None
|
||||
assert chat_session.value == "new_session"
|
||||
|
||||
# Verify regular variables were added
|
||||
input_text = variable_pool.get(["input", "text"])
|
||||
assert input_text is not None
|
||||
assert input_text.value == "user input text"
|
||||
|
||||
process_data = variable_pool.get(["process", "data"])
|
||||
assert process_data is not None
|
||||
assert process_data.value == {"value": 123, "status": "active"}
|
||||
@ -11,12 +11,12 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_init_success_with_all_config(self):
|
||||
"""Test successful initialization when all required config is provided."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with patch("extensions.storage.supabase_storage.Client") as mock_client_class:
|
||||
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
@ -31,7 +31,7 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_init_raises_error_when_url_missing(self):
|
||||
"""Test initialization raises ValueError when SUPABASE_URL is None."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = None
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
@ -41,7 +41,7 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_init_raises_error_when_api_key_missing(self):
|
||||
"""Test initialization raises ValueError when SUPABASE_API_KEY is None."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = None
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
@ -51,7 +51,7 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_init_raises_error_when_bucket_name_missing(self):
|
||||
"""Test initialization raises ValueError when SUPABASE_BUCKET_NAME is None."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = None
|
||||
@ -61,12 +61,12 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_create_bucket_when_not_exists(self):
|
||||
"""Test create_bucket creates bucket when it doesn't exist."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with patch("extensions.storage.supabase_storage.Client") as mock_client_class:
|
||||
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
@ -77,12 +77,12 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_create_bucket_when_exists(self):
|
||||
"""Test create_bucket does not create bucket when it already exists."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with patch("extensions.storage.supabase_storage.Client") as mock_client_class:
|
||||
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
@ -94,12 +94,12 @@ class TestSupabaseStorage:
|
||||
@pytest.fixture
|
||||
def storage_with_mock_client(self):
|
||||
"""Fixture providing SupabaseStorage with mocked client."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with patch("extensions.storage.supabase_storage.Client") as mock_client_class:
|
||||
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
@ -251,12 +251,12 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_bucket_exists_returns_true_when_bucket_found(self):
|
||||
"""Test bucket_exists returns True when bucket is found in list."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with patch("extensions.storage.supabase_storage.Client") as mock_client_class:
|
||||
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
@ -271,12 +271,12 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_bucket_exists_returns_false_when_bucket_not_found(self):
|
||||
"""Test bucket_exists returns False when bucket is not found in list."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with patch("extensions.storage.supabase_storage.Client") as mock_client_class:
|
||||
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
@ -294,12 +294,12 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_bucket_exists_returns_false_when_no_buckets(self):
|
||||
"""Test bucket_exists returns False when no buckets exist."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with patch("extensions.storage.supabase_storage.Client") as mock_client_class:
|
||||
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
|
||||
@ -10,7 +10,6 @@ from services.account_service import AccountService, RegisterService, TenantServ
|
||||
from services.errors.account import (
|
||||
AccountAlreadyInTenantError,
|
||||
AccountLoginError,
|
||||
AccountNotFoundError,
|
||||
AccountPasswordError,
|
||||
AccountRegisterError,
|
||||
CurrentPasswordIncorrectError,
|
||||
@ -195,7 +194,7 @@ class TestAccountService:
|
||||
|
||||
# Execute test and verify exception
|
||||
self._assert_exception_raised(
|
||||
AccountNotFoundError, AccountService.authenticate, "notfound@example.com", "password"
|
||||
AccountPasswordError, AccountService.authenticate, "notfound@example.com", "password"
|
||||
)
|
||||
|
||||
def test_authenticate_account_banned(self, mock_db_dependencies):
|
||||
|
||||
@ -66,7 +66,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables):
|
||||
app_model = MagicMock()
|
||||
app_model.id = "app_id"
|
||||
app_model.tenant_id = "tenant_id"
|
||||
app_model.mode = AppMode.CHAT.value
|
||||
app_model.mode = AppMode.CHAT
|
||||
|
||||
api_based_extension_id = "api_based_extension_id"
|
||||
mock_api_based_extension = APIBasedExtension(
|
||||
@ -127,7 +127,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables):
|
||||
app_model = MagicMock()
|
||||
app_model.id = "app_id"
|
||||
app_model.tenant_id = "tenant_id"
|
||||
app_model.mode = AppMode.WORKFLOW.value
|
||||
app_model.mode = AppMode.WORKFLOW
|
||||
|
||||
api_based_extension_id = "api_based_extension_id"
|
||||
mock_api_based_extension = APIBasedExtension(
|
||||
|
||||
@ -279,8 +279,6 @@ def test_structured_output_parser():
|
||||
]
|
||||
|
||||
for case in testcases:
|
||||
print(f"Running test case: {case['name']}")
|
||||
|
||||
# Setup model entity
|
||||
model_schema = get_model_entity(case["provider"], case["model_name"], case["support_structure_output"])
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ from textwrap import dedent
|
||||
import pytest
|
||||
from yaml import YAMLError
|
||||
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
from core.tools.utils.yaml_utils import _load_yaml_file
|
||||
|
||||
EXAMPLE_YAML_FILE = "example_yaml.yaml"
|
||||
INVALID_YAML_FILE = "invalid_yaml.yaml"
|
||||
@ -56,15 +56,15 @@ def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str:
|
||||
|
||||
|
||||
def test_load_yaml_non_existing_file():
|
||||
assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {}
|
||||
assert load_yaml_file(file_path="") == {}
|
||||
with pytest.raises(FileNotFoundError):
|
||||
_load_yaml_file(file_path=NON_EXISTING_YAML_FILE)
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
load_yaml_file(file_path=NON_EXISTING_YAML_FILE, ignore_error=False)
|
||||
_load_yaml_file(file_path="")
|
||||
|
||||
|
||||
def test_load_valid_yaml_file(prepare_example_yaml_file):
|
||||
yaml_data = load_yaml_file(file_path=prepare_example_yaml_file)
|
||||
yaml_data = _load_yaml_file(file_path=prepare_example_yaml_file)
|
||||
assert len(yaml_data) > 0
|
||||
assert yaml_data["age"] == 30
|
||||
assert yaml_data["gender"] == "male"
|
||||
@ -77,7 +77,4 @@ def test_load_valid_yaml_file(prepare_example_yaml_file):
|
||||
def test_load_invalid_yaml_file(prepare_invalid_yaml_file):
|
||||
# yaml syntax error
|
||||
with pytest.raises(YAMLError):
|
||||
load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=False)
|
||||
|
||||
# ignore error
|
||||
assert load_yaml_file(file_path=prepare_invalid_yaml_file) == {}
|
||||
_load_yaml_file(file_path=prepare_invalid_yaml_file)
|
||||
|
||||
Reference in New Issue
Block a user