Merge branch 'main' into feat/mcp-06-18

This commit is contained in:
Novice
2025-10-13 13:54:01 +08:00
364 changed files with 7548 additions and 3282 deletions

View File

@ -25,7 +25,7 @@ class TestChatMessageApiPermissions:
"""Create a mock App model for testing."""
app = App()
app.id = str(uuid.uuid4())
app.mode = AppMode.CHAT.value
app.mode = AppMode.CHAT
app.tenant_id = str(uuid.uuid4())
app.status = "normal"
return app
@ -33,17 +33,19 @@ class TestChatMessageApiPermissions:
@pytest.fixture
def mock_account(self, monkeypatch: pytest.MonkeyPatch):
"""Create a mock Account for testing."""
account = Account()
account.id = str(uuid.uuid4())
account.name = "Test User"
account.email = "test@example.com"
account = Account(
name="Test User",
email="test@example.com",
)
account.last_active_at = naive_utc_now()
account.created_at = naive_utc_now()
account.updated_at = naive_utc_now()
account.id = str(uuid.uuid4())
tenant = Tenant()
# Create mock tenant
tenant = Tenant(name="Test Tenant")
tenant.id = str(uuid.uuid4())
tenant.name = "Test Tenant"
mock_session_instance = mock.Mock()

View File

@ -23,7 +23,7 @@ class TestModelConfigResourcePermissions:
"""Create a mock App model for testing."""
app = App()
app.id = str(uuid.uuid4())
app.mode = AppMode.CHAT.value
app.mode = AppMode.CHAT
app.tenant_id = str(uuid.uuid4())
app.status = "normal"
app.app_model_config_id = str(uuid.uuid4())
@ -32,17 +32,16 @@ class TestModelConfigResourcePermissions:
@pytest.fixture
def mock_account(self, monkeypatch: pytest.MonkeyPatch):
"""Create a mock Account for testing."""
account = Account()
account = Account(name="Test User", email="test@example.com")
account.id = str(uuid.uuid4())
account.name = "Test User"
account.email = "test@example.com"
account.last_active_at = naive_utc_now()
account.created_at = naive_utc_now()
account.updated_at = naive_utc_now()
tenant = Tenant()
# Create mock tenant
tenant = Tenant(name="Test Tenant")
tenant.id = str(uuid.uuid4())
tenant.name = "Test Tenant"
mock_session_instance = mock.Mock()

View File

@ -542,7 +542,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
index=1,
node_execution_id=str(uuid.uuid4()),
node_id=self._node_id,
node_type=NodeType.LLM.value,
node_type=NodeType.LLM,
title="Test Node",
inputs='{"input": "test input"}',
process_data='{"test_var": "process_value", "other_var": "other_process"}',

View File

@ -36,7 +36,7 @@ def test_api_tool(setup_http_mock):
entity=ToolEntity(
identity=ToolIdentity(provider="", author="", name="", label=I18nObject(en_US="test tool")),
),
api_bundle=ApiToolBundle(**tool_bundle),
api_bundle=ApiToolBundle.model_validate(tool_bundle),
runtime=ToolRuntime(tenant_id="", credentials={"auth_type": "none"}),
provider_id="test_tool",
)

View File

@ -1,5 +1,6 @@
import os
from collections import UserDict
from typing import Any
from unittest.mock import MagicMock
import pytest
@ -9,7 +10,6 @@ from pymochow.model.database import Database # type: ignore
from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState # type: ignore
from pymochow.model.schema import HNSWParams, VectorIndex # type: ignore
from pymochow.model.table import Table # type: ignore
from requests.adapters import HTTPAdapter
class AttrDict(UserDict):
@ -21,7 +21,7 @@ class MockBaiduVectorDBClass:
def mock_vector_db_client(
self,
config=None,
adapter: HTTPAdapter | None = None,
adapter: Any | None = None,
):
self.conn = MagicMock()
self._config = MagicMock()

View File

@ -44,25 +44,25 @@ class MockClient:
"hits": [
{
"_source": {
Field.CONTENT_KEY.value: "abcdef",
Field.VECTOR.value: [1, 2],
Field.METADATA_KEY.value: {},
Field.CONTENT_KEY: "abcdef",
Field.VECTOR: [1, 2],
Field.METADATA_KEY: {},
},
"_score": 1.0,
},
{
"_source": {
Field.CONTENT_KEY.value: "123456",
Field.VECTOR.value: [2, 2],
Field.METADATA_KEY.value: {},
Field.CONTENT_KEY: "123456",
Field.VECTOR: [2, 2],
Field.METADATA_KEY: {},
},
"_score": 0.9,
},
{
"_source": {
Field.CONTENT_KEY.value: "a1b2c3",
Field.VECTOR.value: [3, 2],
Field.METADATA_KEY.value: {},
Field.CONTENT_KEY: "a1b2c3",
Field.VECTOR: [3, 2],
Field.METADATA_KEY: {},
},
"_score": 0.8,
},

View File

@ -1,9 +1,8 @@
import os
from typing import Union
from typing import Any, Union
import pytest
from _pytest.monkeypatch import MonkeyPatch
from requests.adapters import HTTPAdapter
from tcvectordb import RPCVectorDBClient # type: ignore
from tcvectordb.model import enum
from tcvectordb.model.collection import FilterIndexConfig
@ -23,7 +22,7 @@ class MockTcvectordbClass:
key="",
read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
timeout=10,
adapter: HTTPAdapter | None = None,
adapter: Any | None = None,
pool_size: int = 2,
proxies: dict | None = None,
password: str | None = None,

View File

@ -40,13 +40,13 @@ class MockVikingDBClass:
collection_name=collection_name,
description="Collection For Dify",
viking_db_service=self._viking_db_service,
primary_key=vdb_Field.PRIMARY_KEY.value,
primary_key=vdb_Field.PRIMARY_KEY,
fields=[
Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True),
Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String),
Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String),
Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text),
Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=768),
Field(field_name=vdb_Field.PRIMARY_KEY, field_type=FieldType.String, is_primary_key=True),
Field(field_name=vdb_Field.METADATA_KEY, field_type=FieldType.String),
Field(field_name=vdb_Field.GROUP_KEY, field_type=FieldType.String),
Field(field_name=vdb_Field.CONTENT_KEY, field_type=FieldType.Text),
Field(field_name=vdb_Field.VECTOR, field_type=FieldType.Vector, dim=768),
],
indexes=[
Index(
@ -71,7 +71,7 @@ class MockVikingDBClass:
return Collection(
collection_name=collection_name,
description=description,
primary_key=vdb_Field.PRIMARY_KEY.value,
primary_key=vdb_Field.PRIMARY_KEY,
viking_db_service=self._viking_db_service,
fields=fields,
)
@ -126,11 +126,11 @@ class MockVikingDBClass:
def fetch_data(self, id: Union[str, list[str], int, list[int]]):
return Data(
fields={
vdb_Field.GROUP_KEY.value: "test_group",
vdb_Field.METADATA_KEY.value: "{}",
vdb_Field.CONTENT_KEY.value: "content",
vdb_Field.PRIMARY_KEY.value: id,
vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
vdb_Field.GROUP_KEY: "test_group",
vdb_Field.METADATA_KEY: "{}",
vdb_Field.CONTENT_KEY: "content",
vdb_Field.PRIMARY_KEY: id,
vdb_Field.VECTOR: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
},
id=id,
)
@ -151,16 +151,16 @@ class MockVikingDBClass:
return [
Data(
fields={
vdb_Field.GROUP_KEY.value: "test_group",
vdb_Field.METADATA_KEY.value: '\
vdb_Field.GROUP_KEY: "test_group",
vdb_Field.METADATA_KEY: '\
{"source": "/var/folders/ml/xxx/xxx.txt", \
"document_id": "test_document_id", \
"dataset_id": "test_dataset_id", \
"doc_id": "test_id", \
"doc_hash": "test_hash"}',
vdb_Field.CONTENT_KEY.value: "content",
vdb_Field.PRIMARY_KEY.value: "test_id",
vdb_Field.VECTOR.value: vector,
vdb_Field.CONTENT_KEY: "content",
vdb_Field.PRIMARY_KEY: "test_id",
vdb_Field.VECTOR: vector,
},
id="test_id",
score=0.10,
@ -173,16 +173,16 @@ class MockVikingDBClass:
return [
Data(
fields={
vdb_Field.GROUP_KEY.value: "test_group",
vdb_Field.METADATA_KEY.value: '\
vdb_Field.GROUP_KEY: "test_group",
vdb_Field.METADATA_KEY: '\
{"source": "/var/folders/ml/xxx/xxx.txt", \
"document_id": "test_document_id", \
"dataset_id": "test_dataset_id", \
"doc_id": "test_id", \
"doc_hash": "test_hash"}',
vdb_Field.CONTENT_KEY.value: "content",
vdb_Field.PRIMARY_KEY.value: "test_id",
vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
vdb_Field.CONTENT_KEY: "content",
vdb_Field.PRIMARY_KEY: "test_id",
vdb_Field.VECTOR: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
},
id="test_id",
score=0.10,

View File

@ -129,8 +129,8 @@ class TestOpenSearchVector:
"hits": [
{
"_source": {
Field.CONTENT_KEY.value: get_example_text(),
Field.METADATA_KEY.value: {"document_id": self.example_doc_id},
Field.CONTENT_KEY: get_example_text(),
Field.METADATA_KEY: {"document_id": self.example_doc_id},
},
"_score": 1.0,
}

View File

@ -16,6 +16,7 @@ from services.errors.account import (
AccountPasswordError,
AccountRegisterError,
CurrentPasswordIncorrectError,
TenantNotFoundError,
)
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
@ -63,7 +64,7 @@ class TestAccountService:
password=password,
)
assert account.email == email
assert account.status == AccountStatus.ACTIVE.value
assert account.status == AccountStatus.ACTIVE
# Login with correct password
logged_in = AccountService.authenticate(email, password)
@ -184,7 +185,7 @@ class TestAccountService:
)
# Ban the account
account.status = AccountStatus.BANNED.value
account.status = AccountStatus.BANNED
from extensions.ext_database import db
db.session.commit()
@ -268,14 +269,14 @@ class TestAccountService:
interface_language="en-US",
password=password,
)
account.status = AccountStatus.PENDING.value
account.status = AccountStatus.PENDING
from extensions.ext_database import db
db.session.commit()
# Authenticate should activate the account
authenticated_account = AccountService.authenticate(email, password)
assert authenticated_account.status == AccountStatus.ACTIVE.value
assert authenticated_account.status == AccountStatus.ACTIVE
assert authenticated_account.initialized_at is not None
def test_update_account_password_success(self, db_session_with_containers, mock_external_service_dependencies):
@ -538,7 +539,7 @@ class TestAccountService:
from extensions.ext_database import db
db.session.refresh(account)
assert account.status == AccountStatus.CLOSED.value
assert account.status == AccountStatus.CLOSED
def test_update_account_fields(self, db_session_with_containers, mock_external_service_dependencies):
"""
@ -678,7 +679,7 @@ class TestAccountService:
interface_language="en-US",
password=password,
)
account.status = AccountStatus.PENDING.value
account.status = AccountStatus.PENDING
from extensions.ext_database import db
db.session.commit()
@ -687,7 +688,7 @@ class TestAccountService:
token_pair = AccountService.login(account)
db.session.refresh(account)
assert account.status == AccountStatus.ACTIVE.value
assert account.status == AccountStatus.ACTIVE
def test_logout(self, db_session_with_containers, mock_external_service_dependencies):
"""
@ -859,7 +860,7 @@ class TestAccountService:
)
# Ban the account
account.status = AccountStatus.BANNED.value
account.status = AccountStatus.BANNED
from extensions.ext_database import db
db.session.commit()
@ -989,7 +990,7 @@ class TestAccountService:
)
# Ban the account
account.status = AccountStatus.BANNED.value
account.status = AccountStatus.BANNED
from extensions.ext_database import db
db.session.commit()
@ -1414,7 +1415,7 @@ class TestTenantService:
)
# Try to get current tenant (should fail)
with pytest.raises(AttributeError):
with pytest.raises((AttributeError, TenantNotFoundError)):
TenantService.get_current_tenant_by_account(account)
def test_switch_tenant_success(self, db_session_with_containers, mock_external_service_dependencies):

View File

@ -86,7 +86,7 @@ class TestFileService:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
@ -187,7 +187,7 @@ class TestFileService:
assert upload_file.extension == "pdf"
assert upload_file.mime_type == mimetype
assert upload_file.created_by == account.id
assert upload_file.created_by_role == CreatorUserRole.ACCOUNT.value
assert upload_file.created_by_role == CreatorUserRole.ACCOUNT
assert upload_file.used is False
assert upload_file.hash == hashlib.sha3_256(content).hexdigest()
@ -216,7 +216,7 @@ class TestFileService:
assert upload_file is not None
assert upload_file.created_by == end_user.id
assert upload_file.created_by_role == CreatorUserRole.END_USER.value
assert upload_file.created_by_role == CreatorUserRole.END_USER
def test_upload_file_with_datasets_source(
self, db_session_with_containers, engine, mock_external_service_dependencies

View File

@ -72,7 +72,7 @@ class TestMetadataService:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)

View File

@ -103,7 +103,7 @@ class TestModelLoadBalancingService:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)

View File

@ -67,7 +67,7 @@ class TestModelProviderService:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)

View File

@ -66,7 +66,7 @@ class TestTagService:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)

View File

@ -144,7 +144,7 @@ class TestWebConversationService:
system_instruction=fake.text(max_nb_chars=300),
system_instruction_tokens=50,
status="normal",
invoke_from=InvokeFrom.WEB_APP.value,
invoke_from=InvokeFrom.WEB_APP,
from_source="console" if isinstance(user, Account) else "api",
from_end_user_id=user.id if isinstance(user, EndUser) else None,
from_account_id=user.id if isinstance(user, Account) else None,

View File

@ -87,7 +87,7 @@ class TestWebAppAuthService:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
@ -150,7 +150,7 @@ class TestWebAppAuthService:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
@ -232,7 +232,7 @@ class TestWebAppAuthService:
assert result.id == account.id
assert result.email == account.email
assert result.name == account.name
assert result.status == AccountStatus.ACTIVE.value
assert result.status == AccountStatus.ACTIVE
# Verify database state
from extensions.ext_database import db
@ -280,7 +280,7 @@ class TestWebAppAuthService:
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status=AccountStatus.BANNED.value,
status=AccountStatus.BANNED,
)
# Hash password
@ -411,7 +411,7 @@ class TestWebAppAuthService:
assert result.id == account.id
assert result.email == account.email
assert result.name == account.name
assert result.status == AccountStatus.ACTIVE.value
assert result.status == AccountStatus.ACTIVE
# Verify database state
from extensions.ext_database import db
@ -455,7 +455,7 @@ class TestWebAppAuthService:
email=unique_email,
name=fake.name(),
interface_language="en-US",
status=AccountStatus.BANNED.value,
status=AccountStatus.BANNED,
)
from extensions.ext_database import db

View File

@ -199,7 +199,7 @@ class TestWorkflowAppService:
elapsed_time=1.5,
total_tokens=100,
total_steps=3,
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC),
finished_at=datetime.now(UTC),
@ -215,7 +215,7 @@ class TestWorkflowAppService:
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC),
)
@ -356,7 +356,7 @@ class TestWorkflowAppService:
elapsed_time=1.0 + i,
total_tokens=100 + i * 10,
total_steps=3,
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC) + timedelta(minutes=i),
finished_at=datetime.now(UTC) + timedelta(minutes=i + 1) if status != "running" else None,
@ -371,7 +371,7 @@ class TestWorkflowAppService:
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC) + timedelta(minutes=i),
)
@ -464,7 +464,7 @@ class TestWorkflowAppService:
elapsed_time=1.0,
total_tokens=100,
total_steps=3,
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=timestamp,
finished_at=timestamp + timedelta(minutes=1),
@ -479,7 +479,7 @@ class TestWorkflowAppService:
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=timestamp,
)
@ -571,7 +571,7 @@ class TestWorkflowAppService:
elapsed_time=1.0,
total_tokens=100,
total_steps=3,
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC) + timedelta(minutes=i),
finished_at=datetime.now(UTC) + timedelta(minutes=i + 1),
@ -586,7 +586,7 @@ class TestWorkflowAppService:
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC) + timedelta(minutes=i),
)
@ -701,7 +701,7 @@ class TestWorkflowAppService:
elapsed_time=1.0,
total_tokens=100,
total_steps=3,
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC) + timedelta(minutes=i),
finished_at=datetime.now(UTC) + timedelta(minutes=i + 1),
@ -716,7 +716,7 @@ class TestWorkflowAppService:
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC) + timedelta(minutes=i),
)
@ -743,7 +743,7 @@ class TestWorkflowAppService:
elapsed_time=1.0,
total_tokens=100,
total_steps=3,
created_by_role=CreatorUserRole.END_USER.value,
created_by_role=CreatorUserRole.END_USER,
created_by=end_user.id,
created_at=datetime.now(UTC) + timedelta(minutes=i + 10),
finished_at=datetime.now(UTC) + timedelta(minutes=i + 11),
@ -758,7 +758,7 @@ class TestWorkflowAppService:
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="web-app",
created_by_role=CreatorUserRole.END_USER.value,
created_by_role=CreatorUserRole.END_USER,
created_by=end_user.id,
created_at=datetime.now(UTC) + timedelta(minutes=i + 10),
)
@ -780,14 +780,14 @@ class TestWorkflowAppService:
limit=20,
)
assert result_session_filter["total"] == 2
assert all(log.created_by_role == CreatorUserRole.END_USER.value for log in result_session_filter["data"])
assert all(log.created_by_role == CreatorUserRole.END_USER for log in result_session_filter["data"])
# Test filtering by account email
result_account_filter = service.get_paginate_workflow_app_logs(
session=db_session_with_containers, app_model=app, created_by_account=account.email, page=1, limit=20
)
assert result_account_filter["total"] == 3
assert all(log.created_by_role == CreatorUserRole.ACCOUNT.value for log in result_account_filter["data"])
assert all(log.created_by_role == CreatorUserRole.ACCOUNT for log in result_account_filter["data"])
# Test filtering by non-existent session ID
result_no_session = service.get_paginate_workflow_app_logs(
@ -853,7 +853,7 @@ class TestWorkflowAppService:
elapsed_time=1.0,
total_tokens=100,
total_steps=3,
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC),
finished_at=datetime.now(UTC) + timedelta(minutes=1),
@ -869,7 +869,7 @@ class TestWorkflowAppService:
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC),
)
@ -943,7 +943,7 @@ class TestWorkflowAppService:
elapsed_time=0.0, # Edge case: 0 elapsed time
total_tokens=0, # Edge case: 0 tokens
total_steps=0, # Edge case: 0 steps
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC),
finished_at=datetime.now(UTC),
@ -959,7 +959,7 @@ class TestWorkflowAppService:
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC),
)
@ -1098,7 +1098,7 @@ class TestWorkflowAppService:
elapsed_time=1.5,
total_tokens=100,
total_steps=3,
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC) + timedelta(minutes=i),
finished_at=datetime.now(UTC) + timedelta(minutes=i + 1) if status == "succeeded" else None,
@ -1113,7 +1113,7 @@ class TestWorkflowAppService:
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC) + timedelta(minutes=i),
)
@ -1198,7 +1198,7 @@ class TestWorkflowAppService:
elapsed_time=1.5,
total_tokens=100,
total_steps=3,
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC) + timedelta(minutes=i),
finished_at=datetime.now(UTC) + timedelta(minutes=i + 1) if status != "running" else None,
@ -1213,7 +1213,7 @@ class TestWorkflowAppService:
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC) + timedelta(minutes=i),
)
@ -1300,7 +1300,7 @@ class TestWorkflowAppService:
elapsed_time=1.5,
total_tokens=100,
total_steps=3,
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC) + timedelta(minutes=i * 10 + j),
finished_at=datetime.now(UTC) + timedelta(minutes=i * 10 + j + 1),
@ -1315,7 +1315,7 @@ class TestWorkflowAppService:
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC) + timedelta(minutes=i * 10 + j),
)

View File

@ -130,7 +130,7 @@ class TestWorkflowRunService:
elapsed_time=1.5,
total_tokens=100,
total_steps=3,
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=created_time,
finished_at=created_time,
@ -167,7 +167,7 @@ class TestWorkflowRunService:
inputs={},
status="normal",
mode="chat",
from_source=CreatorUserRole.ACCOUNT.value,
from_source=CreatorUserRole.ACCOUNT,
from_account_id=account.id,
)
db.session.add(conversation)
@ -188,7 +188,7 @@ class TestWorkflowRunService:
message.answer_price_unit = 0.001
message.currency = "USD"
message.status = "normal"
message.from_source = CreatorUserRole.ACCOUNT.value
message.from_source = CreatorUserRole.ACCOUNT
message.from_account_id = account.id
message.workflow_run_id = workflow_run.id
message.inputs = {"input": "test input"}
@ -458,7 +458,7 @@ class TestWorkflowRunService:
status="succeeded",
elapsed_time=0.5,
execution_metadata=json.dumps({"tokens": 50}),
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC),
)
@ -689,7 +689,7 @@ class TestWorkflowRunService:
status="succeeded",
elapsed_time=0.5,
execution_metadata=json.dumps({"tokens": 50}),
created_by_role=CreatorUserRole.END_USER.value,
created_by_role=CreatorUserRole.END_USER,
created_by=end_user.id,
created_at=datetime.now(UTC),
)
@ -710,4 +710,4 @@ class TestWorkflowRunService:
assert node_exec.app_id == app.id
assert node_exec.workflow_run_id == workflow_run.id
assert node_exec.created_by == end_user.id
assert node_exec.created_by_role == CreatorUserRole.END_USER.value
assert node_exec.created_by_role == CreatorUserRole.END_USER

View File

@ -44,27 +44,26 @@ class TestWorkflowService:
Account: Created test account instance
"""
fake = fake or Faker()
account = Account()
account.id = fake.uuid4()
account.email = fake.email()
account.name = fake.name()
account.avatar_url = fake.url()
account.tenant_id = fake.uuid4()
account.status = "active"
account.type = "normal"
account.role = "owner"
account.interface_language = "en-US" # Set interface language for Site creation
account = Account(
email=fake.email(),
name=fake.name(),
avatar=fake.url(),
status="active",
interface_language="en-US", # Set interface language for Site creation
)
account.created_at = fake.date_time_this_year()
account.id = fake.uuid4()
account.updated_at = account.created_at
# Create a tenant for the account
from models.account import Tenant
tenant = Tenant()
tenant.id = account.tenant_id
tenant.name = f"Test Tenant {fake.company()}"
tenant.plan = "basic"
tenant.status = "active"
tenant = Tenant(
name=f"Test Tenant {fake.company()}",
plan="basic",
status="active",
)
tenant.id = account.current_tenant_id
tenant.created_at = fake.date_time_this_year()
tenant.updated_at = tenant.created_at
@ -91,20 +90,21 @@ class TestWorkflowService:
App: Created test app instance
"""
fake = fake or Faker()
app = App()
app.id = fake.uuid4()
app.tenant_id = fake.uuid4()
app.name = fake.company()
app.description = fake.text()
app.mode = AppMode.WORKFLOW
app.icon_type = "emoji"
app.icon = "🤖"
app.icon_background = "#FFEAD5"
app.enable_site = True
app.enable_api = True
app.created_by = fake.uuid4()
app = App(
id=fake.uuid4(),
tenant_id=fake.uuid4(),
name=fake.company(),
description=fake.text(),
mode=AppMode.WORKFLOW,
icon_type="emoji",
icon="🤖",
icon_background="#FFEAD5",
enable_site=True,
enable_api=True,
created_by=fake.uuid4(),
workflow_id=None, # Will be set when workflow is created
)
app.updated_by = app.created_by
app.workflow_id = None # Will be set when workflow is created
from extensions.ext_database import db
@ -126,19 +126,20 @@ class TestWorkflowService:
Workflow: Created test workflow instance
"""
fake = fake or Faker()
workflow = Workflow()
workflow.id = fake.uuid4()
workflow.tenant_id = app.tenant_id
workflow.app_id = app.id
workflow.type = WorkflowType.WORKFLOW.value
workflow.version = Workflow.VERSION_DRAFT
workflow.graph = json.dumps({"nodes": [], "edges": []})
workflow.features = json.dumps({"features": []})
# unique_hash is a computed property based on graph and features
workflow.created_by = account.id
workflow.updated_by = account.id
workflow.environment_variables = []
workflow.conversation_variables = []
workflow = Workflow(
id=fake.uuid4(),
tenant_id=app.tenant_id,
app_id=app.id,
type=WorkflowType.WORKFLOW,
version=Workflow.VERSION_DRAFT,
graph=json.dumps({"nodes": [], "edges": []}),
features=json.dumps({"features": []}),
# unique_hash is a computed property based on graph and features
created_by=account.id,
updated_by=account.id,
environment_variables=[],
conversation_variables=[],
)
from extensions.ext_database import db
@ -175,7 +176,7 @@ class TestWorkflowService:
node_execution.node_type = "test_node"
node_execution.title = "Test Node" # Required field
node_execution.status = "succeeded"
node_execution.created_by_role = CreatorUserRole.ACCOUNT.value # Required field
node_execution.created_by_role = CreatorUserRole.ACCOUNT # Required field
node_execution.created_by = account.id # Required field
node_execution.created_at = fake.date_time_this_year()

View File

@ -69,7 +69,7 @@ class TestWorkspaceService:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
@ -111,7 +111,7 @@ class TestWorkspaceService:
assert result["name"] == tenant.name
assert result["plan"] == tenant.plan
assert result["status"] == tenant.status
assert result["role"] == TenantAccountRole.OWNER.value
assert result["role"] == TenantAccountRole.OWNER
assert result["created_at"] == tenant.created_at
assert result["trial_end_reason"] is None
@ -159,7 +159,7 @@ class TestWorkspaceService:
assert result["name"] == tenant.name
assert result["plan"] == tenant.plan
assert result["status"] == tenant.status
assert result["role"] == TenantAccountRole.OWNER.value
assert result["role"] == TenantAccountRole.OWNER
assert result["created_at"] == tenant.created_at
assert result["trial_end_reason"] is None
@ -194,7 +194,7 @@ class TestWorkspaceService:
from extensions.ext_database import db
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
join.role = TenantAccountRole.NORMAL.value
join.role = TenantAccountRole.NORMAL
db.session.commit()
# Setup mocks for feature service
@ -212,7 +212,7 @@ class TestWorkspaceService:
assert result["name"] == tenant.name
assert result["plan"] == tenant.plan
assert result["status"] == tenant.status
assert result["role"] == TenantAccountRole.NORMAL.value
assert result["role"] == TenantAccountRole.NORMAL
assert result["created_at"] == tenant.created_at
assert result["trial_end_reason"] is None
@ -245,7 +245,7 @@ class TestWorkspaceService:
from extensions.ext_database import db
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
join.role = TenantAccountRole.ADMIN.value
join.role = TenantAccountRole.ADMIN
db.session.commit()
# Setup mocks for feature service and tenant service
@ -260,7 +260,7 @@ class TestWorkspaceService:
# Assert: Verify the expected outcomes
assert result is not None
assert result["role"] == TenantAccountRole.ADMIN.value
assert result["role"] == TenantAccountRole.ADMIN
# Verify custom config is included for admin users
assert "custom_config" in result
@ -378,7 +378,7 @@ class TestWorkspaceService:
from extensions.ext_database import db
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
join.role = TenantAccountRole.EDITOR.value
join.role = TenantAccountRole.EDITOR
db.session.commit()
# Setup mocks for feature service and tenant service
@ -394,7 +394,7 @@ class TestWorkspaceService:
# Assert: Verify the expected outcomes
assert result is not None
assert result["role"] == TenantAccountRole.EDITOR.value
assert result["role"] == TenantAccountRole.EDITOR
# Verify custom config is not included for editor users without admin privileges
assert "custom_config" not in result
@ -425,7 +425,7 @@ class TestWorkspaceService:
from extensions.ext_database import db
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
join.role = TenantAccountRole.DATASET_OPERATOR.value
join.role = TenantAccountRole.DATASET_OPERATOR
db.session.commit()
# Setup mocks for feature service and tenant service
@ -441,7 +441,7 @@ class TestWorkspaceService:
# Assert: Verify the expected outcomes
assert result is not None
assert result["role"] == TenantAccountRole.DATASET_OPERATOR.value
assert result["role"] == TenantAccountRole.DATASET_OPERATOR
# Verify custom config is not included for dataset operators without admin privileges
assert "custom_config" not in result

View File

@ -72,7 +72,7 @@ class TestApiToolManageService:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)

View File

@ -81,7 +81,7 @@ class TestMCPToolManageService:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)

View File

@ -168,7 +168,7 @@ class TestToolTransformService:
"""
# Arrange: Setup test data
fake = Faker()
provider_type = ToolProviderType.BUILT_IN.value
provider_type = ToolProviderType.BUILT_IN
provider_name = fake.company()
icon = "🔧"
@ -206,7 +206,7 @@ class TestToolTransformService:
"""
# Arrange: Setup test data
fake = Faker()
provider_type = ToolProviderType.API.value
provider_type = ToolProviderType.API
provider_name = fake.company()
icon = '{"background": "#FF6B6B", "content": "🔧"}'
@ -231,7 +231,7 @@ class TestToolTransformService:
"""
# Arrange: Setup test data with invalid JSON
fake = Faker()
provider_type = ToolProviderType.API.value
provider_type = ToolProviderType.API
provider_name = fake.company()
icon = '{"invalid": json}'
@ -257,7 +257,7 @@ class TestToolTransformService:
"""
# Arrange: Setup test data
fake = Faker()
provider_type = ToolProviderType.WORKFLOW.value
provider_type = ToolProviderType.WORKFLOW
provider_name = fake.company()
icon = {"background": "#FF6B6B", "content": "🔧"}
@ -282,7 +282,7 @@ class TestToolTransformService:
"""
# Arrange: Setup test data
fake = Faker()
provider_type = ToolProviderType.MCP.value
provider_type = ToolProviderType.MCP
provider_name = fake.company()
icon = {"background": "#FF6B6B", "content": "🔧"}
@ -329,7 +329,7 @@ class TestToolTransformService:
# Arrange: Setup test data
fake = Faker()
tenant_id = fake.uuid4()
provider = {"type": ToolProviderType.BUILT_IN.value, "name": fake.company(), "icon": "🔧"}
provider = {"type": ToolProviderType.BUILT_IN, "name": fake.company(), "icon": "🔧"}
# Act: Execute the method under test
ToolTransformService.repack_provider(tenant_id, provider)

View File

@ -66,7 +66,7 @@ class TestWorkflowConverter:
mock_config.model = ModelConfigEntity(
provider="openai",
model="gpt-4",
mode=LLMMode.CHAT.value,
mode=LLMMode.CHAT,
parameters={},
stop=[],
)
@ -120,7 +120,7 @@ class TestWorkflowConverter:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
@ -150,7 +150,7 @@ class TestWorkflowConverter:
app = App(
tenant_id=tenant.id,
name=fake.company(),
mode=AppMode.CHAT.value,
mode=AppMode.CHAT,
icon_type="emoji",
icon="🤖",
icon_background="#FF6B6B",
@ -218,7 +218,7 @@ class TestWorkflowConverter:
# Assert: Verify the expected outcomes
assert new_app is not None
assert new_app.name == "Test Workflow App"
assert new_app.mode == AppMode.ADVANCED_CHAT.value
assert new_app.mode == AppMode.ADVANCED_CHAT
assert new_app.icon_type == "emoji"
assert new_app.icon == "🚀"
assert new_app.icon_background == "#4CAF50"
@ -257,7 +257,7 @@ class TestWorkflowConverter:
app = App(
tenant_id=tenant.id,
name=fake.company(),
mode=AppMode.CHAT.value,
mode=AppMode.CHAT,
icon_type="emoji",
icon="🤖",
icon_background="#FF6B6B",
@ -522,7 +522,7 @@ class TestWorkflowConverter:
model_config = ModelConfigEntity(
provider="openai",
model="gpt-4",
mode=LLMMode.CHAT.value,
mode=LLMMode.CHAT,
parameters={"temperature": 0.7},
stop=[],
)

View File

@ -63,7 +63,7 @@ class TestAddDocumentToIndexTask:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)

View File

@ -84,7 +84,7 @@ class TestBatchCleanDocumentTask:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)

View File

@ -112,7 +112,7 @@ class TestBatchCreateSegmentToIndexTask:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)

View File

@ -91,7 +91,7 @@ class TestCreateSegmentToIndexTask:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)

View File

@ -48,11 +48,8 @@ class TestDeleteSegmentFromIndexTask:
Tenant: Created test tenant instance
"""
fake = fake or Faker()
tenant = Tenant()
tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="active")
tenant.id = fake.uuid4()
tenant.name = f"Test Tenant {fake.company()}"
tenant.plan = "basic"
tenant.status = "active"
tenant.created_at = fake.date_time_this_year()
tenant.updated_at = tenant.created_at
@ -73,16 +70,14 @@ class TestDeleteSegmentFromIndexTask:
Account: Created test account instance
"""
fake = fake or Faker()
account = Account()
account = Account(
name=fake.name(),
email=fake.email(),
avatar=fake.url(),
status="active",
interface_language="en-US",
)
account.id = fake.uuid4()
account.email = fake.email()
account.name = fake.name()
account.avatar_url = fake.url()
account.tenant_id = tenant.id
account.status = "active"
account.type = "normal"
account.role = "owner"
account.interface_language = "en-US"
account.created_at = fake.date_time_this_year()
account.updated_at = account.created_at

View File

@ -69,7 +69,7 @@ class TestDisableSegmentFromIndexTask:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)

View File

@ -43,27 +43,30 @@ class TestDisableSegmentsFromIndexTask:
Account: Created test account instance
"""
fake = fake or Faker()
account = Account()
account = Account(
email=fake.email(),
name=fake.name(),
avatar=fake.url(),
status="active",
interface_language="en-US",
)
account.id = fake.uuid4()
account.email = fake.email()
account.name = fake.name()
account.avatar_url = fake.url()
# monkey-patch attributes for test setup
account.tenant_id = fake.uuid4()
account.status = "active"
account.type = "normal"
account.role = "owner"
account.interface_language = "en-US"
account.created_at = fake.date_time_this_year()
account.updated_at = account.created_at
# Create a tenant for the account
from models.account import Tenant
tenant = Tenant()
tenant = Tenant(
name=f"Test Tenant {fake.company()}",
plan="basic",
status="active",
)
tenant.id = account.tenant_id
tenant.name = f"Test Tenant {fake.company()}"
tenant.plan = "basic"
tenant.status = "active"
tenant.created_at = fake.date_time_this_year()
tenant.updated_at = tenant.created_at
@ -91,20 +94,21 @@ class TestDisableSegmentsFromIndexTask:
Dataset: Created test dataset instance
"""
fake = fake or Faker()
dataset = Dataset()
dataset.id = fake.uuid4()
dataset.tenant_id = account.tenant_id
dataset.name = f"Test Dataset {fake.word()}"
dataset.description = fake.text(max_nb_chars=200)
dataset.provider = "vendor"
dataset.permission = "only_me"
dataset.data_source_type = "upload_file"
dataset.indexing_technique = "high_quality"
dataset.created_by = account.id
dataset.updated_by = account.id
dataset.embedding_model = "text-embedding-ada-002"
dataset.embedding_model_provider = "openai"
dataset.built_in_field_enabled = False
dataset = Dataset(
id=fake.uuid4(),
tenant_id=account.tenant_id,
name=f"Test Dataset {fake.word()}",
description=fake.text(max_nb_chars=200),
provider="vendor",
permission="only_me",
data_source_type="upload_file",
indexing_technique="high_quality",
created_by=account.id,
updated_by=account.id,
embedding_model="text-embedding-ada-002",
embedding_model_provider="openai",
built_in_field_enabled=False,
)
from extensions.ext_database import db
@ -128,6 +132,7 @@ class TestDisableSegmentsFromIndexTask:
"""
fake = fake or Faker()
document = DatasetDocument()
document.id = fake.uuid4()
document.tenant_id = dataset.tenant_id
document.dataset_id = dataset.id
@ -153,7 +158,6 @@ class TestDisableSegmentsFromIndexTask:
document.archived = False
document.doc_form = "text_model" # Use text_model form for testing
document.doc_language = "en"
from extensions.ext_database import db
db.session.add(document)

View File

@ -72,7 +72,7 @@ class TestDocumentIndexingTask:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
@ -154,7 +154,7 @@ class TestDocumentIndexingTask:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)

View File

@ -63,7 +63,7 @@ class TestEnableSegmentsToIndexTask:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)

View File

@ -66,7 +66,7 @@ class TestMailAccountDeletionTask:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)

View File

@ -65,7 +65,7 @@ class TestMailChangeMailTask:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(join)

View File

@ -95,10 +95,10 @@ class TestMailInviteMemberTask:
name=fake.name(),
password=fake.password(),
interface_language="en-US",
status=AccountStatus.ACTIVE.value,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
status=AccountStatus.ACTIVE,
)
account.created_at = datetime.now(UTC)
account.updated_at = datetime.now(UTC)
db_session_with_containers.add(account)
db_session_with_containers.commit()
db_session_with_containers.refresh(account)
@ -106,9 +106,9 @@ class TestMailInviteMemberTask:
# Create tenant
tenant = Tenant(
name=fake.company(),
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
tenant.created_at = datetime.now(UTC)
tenant.updated_at = datetime.now(UTC)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
db_session_with_containers.refresh(tenant)
@ -117,9 +117,9 @@ class TestMailInviteMemberTask:
tenant_join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
created_at=datetime.now(UTC),
role=TenantAccountRole.OWNER,
)
tenant_join.created_at = datetime.now(UTC)
db_session_with_containers.add(tenant_join)
db_session_with_containers.commit()
@ -163,10 +163,11 @@ class TestMailInviteMemberTask:
name=email.split("@")[0],
password="",
interface_language="en-US",
status=AccountStatus.PENDING.value,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
status=AccountStatus.PENDING,
)
account.created_at = datetime.now(UTC)
account.updated_at = datetime.now(UTC)
db_session_with_containers.add(account)
db_session_with_containers.commit()
db_session_with_containers.refresh(account)
@ -175,9 +176,9 @@ class TestMailInviteMemberTask:
tenant_join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.NORMAL.value,
created_at=datetime.now(UTC),
role=TenantAccountRole.NORMAL,
)
tenant_join.created_at = datetime.now(UTC)
db_session_with_containers.add(tenant_join)
db_session_with_containers.commit()
@ -485,7 +486,7 @@ class TestMailInviteMemberTask:
db_session_with_containers.refresh(pending_account)
db_session_with_containers.refresh(tenant)
assert pending_account.status == AccountStatus.PENDING.value
assert pending_account.status == AccountStatus.PENDING
assert pending_account.email == invitee_email
assert tenant.name is not None
@ -496,7 +497,7 @@ class TestMailInviteMemberTask:
.first()
)
assert tenant_join is not None
assert tenant_join.role == TenantAccountRole.NORMAL.value
assert tenant_join.role == TenantAccountRole.NORMAL
def test_send_invite_member_mail_token_lifecycle_management(
self, db_session_with_containers, mock_external_service_dependencies

View File

@ -143,7 +143,7 @@ class TestOAuthCallback:
oauth_provider.get_user_info.return_value = OAuthUserInfo(id="123", name="Test User", email="test@example.com")
account = MagicMock()
account.status = AccountStatus.ACTIVE.value
account.status = AccountStatus.ACTIVE
token_pair = MagicMock()
token_pair.access_token = "jwt_access_token"
@ -220,11 +220,11 @@ class TestOAuthCallback:
@pytest.mark.parametrize(
("account_status", "expected_redirect"),
[
(AccountStatus.BANNED.value, "http://localhost:3000/signin?message=Account is banned."),
(AccountStatus.BANNED, "http://localhost:3000/signin?message=Account is banned."),
# CLOSED status: Currently NOT handled, will proceed to login (security issue)
# This documents actual behavior. See test_defensive_check_for_closed_account_status for details
(
AccountStatus.CLOSED.value,
AccountStatus.CLOSED,
"http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token",
),
],
@ -296,13 +296,13 @@ class TestOAuthCallback:
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
mock_account = MagicMock()
mock_account.status = AccountStatus.PENDING.value
mock_account.status = AccountStatus.PENDING
mock_generate_account.return_value = mock_account
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
resource.get("github")
assert mock_account.status == AccountStatus.ACTIVE.value
assert mock_account.status == AccountStatus.ACTIVE
assert mock_account.initialized_at is not None
mock_db.session.commit.assert_called_once()
@ -352,7 +352,7 @@ class TestOAuthCallback:
# Create account with CLOSED status
closed_account = MagicMock()
closed_account.status = AccountStatus.CLOSED.value
closed_account.status = AccountStatus.CLOSED
closed_account.id = "123"
closed_account.name = "Closed Account"
mock_generate_account.return_value = closed_account

View File

@ -60,7 +60,7 @@ class TestAccountInitialization:
return "success"
# Act
with patch("controllers.console.wraps.current_user", mock_user):
with patch("controllers.console.wraps._current_account", return_value=mock_user):
result = protected_view()
# Assert
@ -77,7 +77,7 @@ class TestAccountInitialization:
return "success"
# Act & Assert
with patch("controllers.console.wraps.current_user", mock_user):
with patch("controllers.console.wraps._current_account", return_value=mock_user):
with pytest.raises(AccountNotInitializedError):
protected_view()
@ -163,7 +163,7 @@ class TestBillingResourceLimits:
return "member_added"
# Act
with patch("controllers.console.wraps.current_user"):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
result = add_member()
@ -185,7 +185,7 @@ class TestBillingResourceLimits:
# Act & Assert
with app.test_request_context():
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with pytest.raises(Exception) as exc_info:
add_member()
@ -207,7 +207,7 @@ class TestBillingResourceLimits:
# Test 1: Should reject when source is datasets
with app.test_request_context("/?source=datasets"):
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with pytest.raises(Exception) as exc_info:
upload_document()
@ -215,7 +215,7 @@ class TestBillingResourceLimits:
# Test 2: Should allow when source is not datasets
with app.test_request_context("/?source=other"):
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
result = upload_document()
assert result == "document_uploaded"
@ -239,7 +239,7 @@ class TestRateLimiting:
return "knowledge_success"
# Act
with patch("controllers.console.wraps.current_user"):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch(
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
):
@ -271,7 +271,7 @@ class TestRateLimiting:
# Act & Assert
with app.test_request_context():
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch(
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
):

View File

@ -0,0 +1,722 @@
import json
import unittest
from unittest.mock import MagicMock, patch
import pytest
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
AlibabaCloudMySQLVector,
AlibabaCloudMySQLVectorConfig,
)
from core.rag.models.document import Document
try:
from mysql.connector import Error as MySQLError
except ImportError:
# Fallback for testing environments where mysql-connector-python might not be installed
class MySQLError(Exception):
def __init__(self, errno, msg):
self.errno = errno
self.msg = msg
super().__init__(msg)
class TestAlibabaCloudMySQLVector(unittest.TestCase):
def setUp(self):
self.config = AlibabaCloudMySQLVectorConfig(
host="localhost",
port=3306,
user="test_user",
password="test_password",
database="test_db",
max_connection=5,
charset="utf8mb4",
)
self.collection_name = "test_collection"
# Sample documents for testing
self.sample_documents = [
Document(
page_content="This is a test document about AI.",
metadata={"doc_id": "doc1", "document_id": "dataset1", "source": "test"},
),
Document(
page_content="Another document about machine learning.",
metadata={"doc_id": "doc2", "document_id": "dataset1", "source": "test"},
),
]
# Sample embeddings
self.sample_embeddings = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_init(self, mock_pool_class):
"""Test AlibabaCloudMySQLVector initialization."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor for vector support check
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [
{"VERSION()": "8.0.36"}, # Version check
{"vector_support": True}, # Vector support check
]
alibabacloud_mysql_vector = AlibabaCloudMySQLVector(self.collection_name, self.config)
assert alibabacloud_mysql_vector.collection_name == self.collection_name
assert alibabacloud_mysql_vector.table_name == self.collection_name.lower()
assert alibabacloud_mysql_vector.get_type() == "alibabacloud_mysql"
assert alibabacloud_mysql_vector.distance_function == "cosine"
assert alibabacloud_mysql_vector.pool is not None
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
def test_create_collection(self, mock_redis, mock_pool_class):
"""Test collection creation."""
# Mock Redis operations
mock_redis.lock.return_value.__enter__ = MagicMock()
mock_redis.lock.return_value.__exit__ = MagicMock()
mock_redis.get.return_value = None
mock_redis.set.return_value = None
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [
{"VERSION()": "8.0.36"}, # Version check
{"vector_support": True}, # Vector support check
]
alibabacloud_mysql_vector = AlibabaCloudMySQLVector(self.collection_name, self.config)
alibabacloud_mysql_vector._create_collection(768)
# Verify SQL execution calls - should include table creation and index creation
assert mock_cursor.execute.called
assert mock_cursor.execute.call_count >= 3 # CREATE TABLE + 2 indexes
mock_redis.set.assert_called_once()
def test_config_validation(self):
"""Test configuration validation."""
# Test missing required fields
with pytest.raises(ValueError):
AlibabaCloudMySQLVectorConfig(
host="", # Empty host should raise error
port=3306,
user="test",
password="test",
database="test",
max_connection=5,
)
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_vector_support_check_success(self, mock_pool_class):
"""Test successful vector support check."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
# Should not raise an exception
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
assert vector_store is not None
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_vector_support_check_failure(self, mock_pool_class):
"""Test vector support check failure."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.35"}, {"vector_support": False}]
with pytest.raises(ValueError) as context:
AlibabaCloudMySQLVector(self.collection_name, self.config)
assert "RDS MySQL Vector functions are not available" in str(context.value)
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_vector_support_check_function_error(self, mock_pool_class):
"""Test vector support check with function not found error."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.return_value = {"VERSION()": "8.0.36"}
mock_cursor.execute.side_effect = [None, MySQLError(errno=1305, msg="FUNCTION VEC_FromText does not exist")]
with pytest.raises(ValueError) as context:
AlibabaCloudMySQLVector(self.collection_name, self.config)
assert "RDS MySQL Vector functions are not available" in str(context.value)
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
def test_create_documents(self, mock_redis, mock_pool_class):
"""Test creating documents with embeddings."""
# Setup mocks
self._setup_mocks(mock_redis, mock_pool_class)
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
result = vector_store.create(self.sample_documents, self.sample_embeddings)
assert len(result) == 2
assert "doc1" in result
assert "doc2" in result
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_add_texts(self, mock_pool_class):
"""Test adding texts to the vector store."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
result = vector_store.add_texts(self.sample_documents, self.sample_embeddings)
assert len(result) == 2
mock_cursor.executemany.assert_called_once()
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_text_exists(self, mock_pool_class):
"""Test checking if text exists."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [
{"VERSION()": "8.0.36"},
{"vector_support": True},
{"id": "doc1"}, # Text exists
]
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
exists = vector_store.text_exists("doc1")
assert exists
# Check that the correct SQL was executed (last call after init)
execute_calls = mock_cursor.execute.call_args_list
last_call = execute_calls[-1]
assert "SELECT id FROM" in last_call[0][0]
assert last_call[0][1] == ("doc1",)
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_text_not_exists(self, mock_pool_class):
"""Test checking if text does not exist."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [
{"VERSION()": "8.0.36"},
{"vector_support": True},
None, # Text does not exist
]
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
exists = vector_store.text_exists("nonexistent")
assert not exists
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_get_by_ids(self, mock_pool_class):
"""Test getting documents by IDs."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
mock_cursor.__iter__ = lambda self: iter(
[
{"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1"},
{"meta": json.dumps({"doc_id": "doc2", "source": "test"}), "text": "Test document 2"},
]
)
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
docs = vector_store.get_by_ids(["doc1", "doc2"])
assert len(docs) == 2
assert docs[0].page_content == "Test document 1"
assert docs[1].page_content == "Test document 2"
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_get_by_ids_empty_list(self, mock_pool_class):
"""Test getting documents with empty ID list."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
docs = vector_store.get_by_ids([])
assert len(docs) == 0
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_delete_by_ids(self, mock_pool_class):
"""Test deleting documents by IDs."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
vector_store.delete_by_ids(["doc1", "doc2"])
# Check that delete SQL was executed
execute_calls = mock_cursor.execute.call_args_list
delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
assert len(delete_calls) == 1
delete_call = delete_calls[0]
assert "DELETE FROM" in delete_call[0][0]
assert delete_call[0][1] == ["doc1", "doc2"]
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_delete_by_ids_empty_list(self, mock_pool_class):
"""Test deleting with empty ID list."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
vector_store.delete_by_ids([]) # Should not raise an exception
# Verify no delete SQL was executed
execute_calls = mock_cursor.execute.call_args_list
delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
assert len(delete_calls) == 0
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_delete_by_ids_table_not_exists(self, mock_pool_class):
"""Test deleting when table doesn't exist."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
# Simulate table doesn't exist error on delete
def execute_side_effect(*args, **kwargs):
if "DELETE" in args[0]:
raise MySQLError(errno=1146, msg="Table doesn't exist")
mock_cursor.execute.side_effect = execute_side_effect
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
# Should not raise an exception
vector_store.delete_by_ids(["doc1"])
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_delete_by_metadata_field(self, mock_pool_class):
"""Test deleting documents by metadata field."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
vector_store.delete_by_metadata_field("document_id", "dataset1")
# Check that the correct SQL was executed
execute_calls = mock_cursor.execute.call_args_list
delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
assert len(delete_calls) == 1
delete_call = delete_calls[0]
assert "JSON_UNQUOTE(JSON_EXTRACT(meta" in delete_call[0][0]
assert delete_call[0][1] == ("$.document_id", "dataset1")
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_vector_cosine(self, mock_pool_class):
"""Test vector search with cosine distance."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
mock_cursor.__iter__ = lambda self: iter(
[{"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1", "distance": 0.1}]
)
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
query_vector = [0.1, 0.2, 0.3, 0.4]
docs = vector_store.search_by_vector(query_vector, top_k=5)
assert len(docs) == 1
assert docs[0].page_content == "Test document 1"
assert abs(docs[0].metadata["score"] - 0.9) < 0.1 # 1 - 0.1 = 0.9
assert docs[0].metadata["distance"] == 0.1
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_vector_euclidean(self, mock_pool_class):
"""Test vector search with euclidean distance."""
config = AlibabaCloudMySQLVectorConfig(
host="localhost",
port=3306,
user="test_user",
password="test_password",
database="test_db",
max_connection=5,
distance_function="euclidean",
)
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
mock_cursor.__iter__ = lambda self: iter(
[{"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1", "distance": 2.0}]
)
vector_store = AlibabaCloudMySQLVector(self.collection_name, config)
query_vector = [0.1, 0.2, 0.3, 0.4]
docs = vector_store.search_by_vector(query_vector, top_k=5)
assert len(docs) == 1
assert abs(docs[0].metadata["score"] - 1.0 / 3.0) < 0.01 # 1/(1+2) = 1/3
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_vector_with_filter(self, mock_pool_class):
"""Test vector search with document ID filter."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
mock_cursor.__iter__ = lambda self: iter([])
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
query_vector = [0.1, 0.2, 0.3, 0.4]
docs = vector_store.search_by_vector(query_vector, top_k=5, document_ids_filter=["dataset1"])
# Verify the SQL contains the WHERE clause for filtering
execute_calls = mock_cursor.execute.call_args_list
search_calls = [call for call in execute_calls if "VEC_DISTANCE" in str(call)]
assert len(search_calls) > 0
search_call = search_calls[0]
assert "WHERE JSON_UNQUOTE" in search_call[0][0]
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_vector_with_score_threshold(self, mock_pool_class):
"""Test vector search with score threshold."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
mock_cursor.__iter__ = lambda self: iter(
[
{
"meta": json.dumps({"doc_id": "doc1", "source": "test"}),
"text": "High similarity document",
"distance": 0.1, # High similarity (score = 0.9)
},
{
"meta": json.dumps({"doc_id": "doc2", "source": "test"}),
"text": "Low similarity document",
"distance": 0.8, # Low similarity (score = 0.2)
},
]
)
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
query_vector = [0.1, 0.2, 0.3, 0.4]
docs = vector_store.search_by_vector(query_vector, top_k=5, score_threshold=0.5)
# Only the high similarity document should be returned
assert len(docs) == 1
assert docs[0].page_content == "High similarity document"
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_vector_invalid_top_k(self, mock_pool_class):
"""Test vector search with invalid top_k."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
query_vector = [0.1, 0.2, 0.3, 0.4]
with pytest.raises(ValueError):
vector_store.search_by_vector(query_vector, top_k=0)
with pytest.raises(ValueError):
vector_store.search_by_vector(query_vector, top_k="invalid")
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_full_text(self, mock_pool_class):
"""Test full-text search."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
mock_cursor.__iter__ = lambda self: iter(
[
{
"meta": {"doc_id": "doc1", "source": "test"},
"text": "This document contains machine learning content",
"score": 1.5,
}
]
)
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
docs = vector_store.search_by_full_text("machine learning", top_k=5)
assert len(docs) == 1
assert docs[0].page_content == "This document contains machine learning content"
assert docs[0].metadata["score"] == 1.5
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_full_text_with_filter(self, mock_pool_class):
"""Test full-text search with document ID filter."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
mock_cursor.__iter__ = lambda self: iter([])
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
docs = vector_store.search_by_full_text("machine learning", top_k=5, document_ids_filter=["dataset1"])
# Verify the SQL contains the AND clause for filtering
execute_calls = mock_cursor.execute.call_args_list
search_calls = [call for call in execute_calls if "MATCH" in str(call)]
assert len(search_calls) > 0
search_call = search_calls[0]
assert "AND JSON_UNQUOTE" in search_call[0][0]
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_full_text_invalid_top_k(self, mock_pool_class):
"""Test full-text search with invalid top_k."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
with pytest.raises(ValueError):
vector_store.search_by_full_text("test", top_k=0)
with pytest.raises(ValueError):
vector_store.search_by_full_text("test", top_k="invalid")
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_delete_collection(self, mock_pool_class):
"""Test deleting the entire collection."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
vector_store.delete()
# Check that DROP TABLE SQL was executed
execute_calls = mock_cursor.execute.call_args_list
drop_calls = [call for call in execute_calls if "DROP TABLE" in str(call)]
assert len(drop_calls) == 1
drop_call = drop_calls[0]
assert f"DROP TABLE IF EXISTS {self.collection_name.lower()}" in drop_call[0][0]
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_unsupported_distance_function(self, mock_pool_class):
"""Test that Pydantic validation rejects unsupported distance functions."""
# Test that creating config with unsupported distance function raises ValidationError
with pytest.raises(ValueError) as context:
AlibabaCloudMySQLVectorConfig(
host="localhost",
port=3306,
user="test_user",
password="test_password",
database="test_db",
max_connection=5,
distance_function="manhattan", # Unsupported - not in Literal["cosine", "euclidean"]
)
# The error should be related to validation
assert "Input should be 'cosine' or 'euclidean'" in str(context.value) or "manhattan" in str(context.value)
def _setup_mocks(self, mock_redis, mock_pool_class):
"""Helper method to setup common mocks."""
# Mock Redis operations
mock_redis.lock.return_value.__enter__ = MagicMock()
mock_redis.lock.return_value.__exit__ = MagicMock()
mock_redis.get.return_value = None
mock_redis.set.return_value = None
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
if __name__ == "__main__":
unittest.main()

View File

@ -11,8 +11,8 @@ def test_default_value():
config = valid_config.copy()
del config[key]
with pytest.raises(ValidationError) as e:
MilvusConfig(**config)
MilvusConfig.model_validate(config)
assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required"
config = MilvusConfig(**valid_config)
config = MilvusConfig.model_validate(valid_config)
assert config.database == "default"

View File

@ -1,10 +1,12 @@
import os
from pytest_mock import MockerFixture
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response
def test_firecrawl_web_extractor_crawl_mode(mocker):
def test_firecrawl_web_extractor_crawl_mode(mocker: MockerFixture):
url = "https://firecrawl.dev"
api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-"
base_url = "https://api.firecrawl.dev"
@ -18,7 +20,7 @@ def test_firecrawl_web_extractor_crawl_mode(mocker):
mocked_firecrawl = {
"id": "test",
}
mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl))
mocker.patch("httpx.post", return_value=_mock_response(mocked_firecrawl))
job_id = firecrawl_app.crawl_url(url, params)
assert job_id is not None

View File

@ -1,5 +1,7 @@
from unittest import mock
from pytest_mock import MockerFixture
from core.rag.extractor import notion_extractor
user_id = "user1"
@ -57,7 +59,7 @@ def _remove_multiple_new_lines(text):
return text.strip()
def test_notion_page(mocker):
def test_notion_page(mocker: MockerFixture):
texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"]
mocked_notion_page = {
"object": "list",
@ -69,7 +71,7 @@ def test_notion_page(mocker):
],
"next_cursor": None,
}
mocker.patch("requests.request", return_value=_mock_response(mocked_notion_page))
mocker.patch("httpx.request", return_value=_mock_response(mocked_notion_page))
page_docs = extractor._load_data_as_documents(page_id, "page")
assert len(page_docs) == 1
@ -77,14 +79,14 @@ def test_notion_page(mocker):
assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1"
def test_notion_database(mocker):
def test_notion_database(mocker: MockerFixture):
page_title_list = ["page1", "page2", "page3"]
mocked_notion_database = {
"object": "list",
"results": [_generate_page(i) for i in page_title_list],
"next_cursor": None,
}
mocker.patch("requests.post", return_value=_mock_response(mocked_notion_database))
mocker.patch("httpx.post", return_value=_mock_response(mocked_notion_database))
database_docs = extractor._load_data_as_documents(database_id, "database")
assert len(database_docs) == 1
content = _remove_multiple_new_lines(database_docs[0].page_content)

View File

@ -140,7 +140,7 @@ class TestCeleryWorkflowExecutionRepository:
assert call_args["execution_data"] == sample_workflow_execution.model_dump()
assert call_args["tenant_id"] == mock_account.current_tenant_id
assert call_args["app_id"] == "test-app"
assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN.value
assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN
assert call_args["creator_user_id"] == mock_account.id
# Verify no task tracking occurs (no _pending_saves attribute)

View File

@ -149,7 +149,7 @@ class TestCeleryWorkflowNodeExecutionRepository:
assert call_args["execution_data"] == sample_workflow_node_execution.model_dump()
assert call_args["tenant_id"] == mock_account.current_tenant_id
assert call_args["app_id"] == "test-app"
assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
assert call_args["creator_user_id"] == mock_account.id
# Verify execution is cached

View File

@ -145,12 +145,12 @@ class TestSQLAlchemyWorkflowNodeExecutionRepositoryTruncation:
db_model.index = 1
db_model.predecessor_node_id = None
db_model.node_id = "node-id"
db_model.node_type = NodeType.LLM.value
db_model.node_type = NodeType.LLM
db_model.title = "Test Node"
db_model.inputs = json.dumps({"value": "inputs"})
db_model.process_data = json.dumps({"value": "process_data"})
db_model.outputs = json.dumps({"value": "outputs"})
db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
db_model.error = None
db_model.elapsed_time = 1.0
db_model.execution_metadata = "{}"

View File

@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
import pytest
import redis
from pytest_mock import MockerFixture
from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.model_manager import LBModelManager
@ -39,7 +40,7 @@ def lb_model_manager():
return lb_model_manager
def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
def test_lb_model_manager_fetch_next(mocker: MockerFixture, lb_model_manager: LBModelManager):
# initialize redis client
redis_client.initialize(redis.Redis())

View File

@ -14,7 +14,13 @@ from core.entities.provider_entities import (
)
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from core.model_runtime.entities.provider_entities import (
ConfigurateMethod,
CredentialFormSchema,
FormOption,
FormType,
ProviderEntity,
)
from models.provider import Provider, ProviderType
@ -306,3 +312,174 @@ class TestProviderConfiguration:
# Assert
assert credentials == {"openai_api_key": "test_key"}
def test_extract_secret_variables_with_secret_input(self, provider_configuration):
"""Test extracting secret variables from credential form schemas"""
# Arrange
credential_form_schemas = [
CredentialFormSchema(
variable="api_key",
label=I18nObject(en_US="API Key", zh_Hans="API 密钥"),
type=FormType.SECRET_INPUT,
required=True,
),
CredentialFormSchema(
variable="model_name",
label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
type=FormType.TEXT_INPUT,
required=True,
),
CredentialFormSchema(
variable="secret_token",
label=I18nObject(en_US="Secret Token", zh_Hans="密钥令牌"),
type=FormType.SECRET_INPUT,
required=False,
),
]
# Act
secret_variables = provider_configuration.extract_secret_variables(credential_form_schemas)
# Assert
assert len(secret_variables) == 2
assert "api_key" in secret_variables
assert "secret_token" in secret_variables
assert "model_name" not in secret_variables
def test_extract_secret_variables_no_secret_input(self, provider_configuration):
"""Test extracting secret variables when no secret input fields exist"""
# Arrange
credential_form_schemas = [
CredentialFormSchema(
variable="model_name",
label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
type=FormType.TEXT_INPUT,
required=True,
),
CredentialFormSchema(
variable="temperature",
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
type=FormType.SELECT,
required=True,
options=[FormOption(label=I18nObject(en_US="0.1", zh_Hans="0.1"), value="0.1")],
),
]
# Act
secret_variables = provider_configuration.extract_secret_variables(credential_form_schemas)
# Assert
assert len(secret_variables) == 0
def test_extract_secret_variables_empty_list(self, provider_configuration):
"""Test extracting secret variables from empty credential form schemas"""
# Arrange
credential_form_schemas = []
# Act
secret_variables = provider_configuration.extract_secret_variables(credential_form_schemas)
# Assert
assert len(secret_variables) == 0
@patch("core.entities.provider_configuration.encrypter")
def test_obfuscated_credentials_with_secret_variables(self, mock_encrypter, provider_configuration):
"""Test obfuscating credentials with secret variables"""
# Arrange
credentials = {
"api_key": "sk-1234567890abcdef",
"model_name": "gpt-4",
"secret_token": "secret_value_123",
"temperature": "0.7",
}
credential_form_schemas = [
CredentialFormSchema(
variable="api_key",
label=I18nObject(en_US="API Key", zh_Hans="API 密钥"),
type=FormType.SECRET_INPUT,
required=True,
),
CredentialFormSchema(
variable="model_name",
label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
type=FormType.TEXT_INPUT,
required=True,
),
CredentialFormSchema(
variable="secret_token",
label=I18nObject(en_US="Secret Token", zh_Hans="密钥令牌"),
type=FormType.SECRET_INPUT,
required=False,
),
CredentialFormSchema(
variable="temperature",
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
type=FormType.TEXT_INPUT,
required=True,
),
]
mock_encrypter.obfuscated_token.side_effect = lambda x: f"***{x[-4:]}"
# Act
obfuscated = provider_configuration.obfuscated_credentials(credentials, credential_form_schemas)
# Assert
assert obfuscated["api_key"] == "***cdef"
assert obfuscated["model_name"] == "gpt-4" # Not obfuscated
assert obfuscated["secret_token"] == "***_123"
assert obfuscated["temperature"] == "0.7" # Not obfuscated
# Verify encrypter was called for secret fields only
assert mock_encrypter.obfuscated_token.call_count == 2
mock_encrypter.obfuscated_token.assert_any_call("sk-1234567890abcdef")
mock_encrypter.obfuscated_token.assert_any_call("secret_value_123")
def test_obfuscated_credentials_no_secret_variables(self, provider_configuration):
"""Test obfuscating credentials when no secret variables exist"""
# Arrange
credentials = {
"model_name": "gpt-4",
"temperature": "0.7",
"max_tokens": "1000",
}
credential_form_schemas = [
CredentialFormSchema(
variable="model_name",
label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
type=FormType.TEXT_INPUT,
required=True,
),
CredentialFormSchema(
variable="temperature",
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
type=FormType.TEXT_INPUT,
required=True,
),
CredentialFormSchema(
variable="max_tokens",
label=I18nObject(en_US="Max Tokens", zh_Hans="最大令牌数"),
type=FormType.TEXT_INPUT,
required=True,
),
]
# Act
obfuscated = provider_configuration.obfuscated_credentials(credentials, credential_form_schemas)
# Assert
assert obfuscated == credentials # No changes expected
def test_obfuscated_credentials_empty_credentials(self, provider_configuration):
"""Test obfuscating empty credentials"""
# Arrange
credentials = {}
credential_form_schemas = []
# Act
obfuscated = provider_configuration.obfuscated_credentials(credentials, credential_form_schemas)
# Assert
assert obfuscated == {}

View File

@ -1,4 +1,5 @@
import pytest
from pytest_mock import MockerFixture
from core.entities.provider_entities import ModelSettings
from core.model_runtime.entities.model_entities import ModelType
@ -7,19 +8,25 @@ from models.provider import LoadBalancingModelConfig, ProviderModelSetting
@pytest.fixture
def mock_provider_entity(mocker):
def mock_provider_entity(mocker: MockerFixture):
mock_entity = mocker.Mock()
mock_entity.provider = "openai"
mock_entity.configurate_methods = ["predefined-model"]
mock_entity.supported_model_types = [ModelType.LLM]
mock_entity.model_credential_schema = mocker.Mock()
mock_entity.model_credential_schema.credential_form_schemas = []
# Use PropertyMock to ensure credential_form_schemas is iterable
provider_credential_schema = mocker.Mock()
type(provider_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
mock_entity.provider_credential_schema = provider_credential_schema
model_credential_schema = mocker.Mock()
type(model_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
mock_entity.model_credential_schema = model_credential_schema
return mock_entity
def test__to_model_settings(mocker, mock_provider_entity):
def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
# Mocking the inputs
provider_model_settings = [
ProviderModelSetting(
@ -79,7 +86,7 @@ def test__to_model_settings(mocker, mock_provider_entity):
assert result[0].load_balancing_configs[1].name == "first"
def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity):
# Mocking the inputs
provider_model_settings = [
ProviderModelSetting(
@ -127,7 +134,7 @@ def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
assert len(result[0].load_balancing_configs) == 0
def test__to_model_settings_lb_disabled(mocker, mock_provider_entity):
def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity):
# Mocking the inputs
provider_model_settings = [
ProviderModelSetting(

View File

@ -147,7 +147,7 @@ class TestRedisChannel:
"""Test deserializing an abort command."""
channel = RedisChannel(MagicMock(), "test:key")
abort_data = {"command_type": CommandType.ABORT.value}
abort_data = {"command_type": CommandType.ABORT}
command = channel._deserialize_command(abort_data)
assert isinstance(command, AbortCommand)
@ -158,7 +158,7 @@ class TestRedisChannel:
channel = RedisChannel(MagicMock(), "test:key")
# For now, only ABORT is supported, but test generic handling
generic_data = {"command_type": CommandType.ABORT.value}
generic_data = {"command_type": CommandType.ABORT}
command = channel._deserialize_command(generic_data)
assert command is not None

View File

@ -56,8 +56,8 @@ def test_mock_iteration_node_preserves_config():
workflow_id="test",
graph_config={"nodes": [], "edges": []},
user_id="test",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
@ -117,8 +117,8 @@ def test_mock_loop_node_preserves_config():
workflow_id="test",
graph_config={"nodes": [], "edges": []},
user_id="test",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)

View File

@ -49,7 +49,7 @@ class TestRedisStopIntegration:
# Verify the command data
command_json = calls[0][0][1]
command_data = json.loads(command_json)
assert command_data["command_type"] == CommandType.ABORT.value
assert command_data["command_type"] == CommandType.ABORT
assert command_data["reason"] == "Test stop"
def test_graph_engine_manager_handles_redis_failure_gracefully(self):
@ -122,7 +122,7 @@ class TestRedisStopIntegration:
# Verify serialized command
command_json = calls[0][0][1]
command_data = json.loads(command_json)
assert command_data["command_type"] == CommandType.ABORT.value
assert command_data["command_type"] == CommandType.ABORT
assert command_data["reason"] == "User requested stop"
# Check expire was set
@ -137,9 +137,7 @@ class TestRedisStopIntegration:
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
# Mock command data
abort_command_json = json.dumps(
{"command_type": CommandType.ABORT.value, "reason": "Test abort", "payload": None}
)
abort_command_json = json.dumps({"command_type": CommandType.ABORT, "reason": "Test abort", "payload": None})
# Mock pipeline execute to return commands
mock_pipeline.execute.return_value = [

View File

@ -35,7 +35,7 @@ def list_operator_node():
"extract_by": ExtractConfig(enabled=False, serial="1"),
"title": "Test Title",
}
node_data = ListOperatorNodeData(**config)
node_data = ListOperatorNodeData.model_validate(config)
node_config = {
"id": "test_node_id",
"data": node_data.model_dump(),

View File

@ -17,7 +17,7 @@ def test_init_question_classifier_node_data():
"vision": {"enabled": True, "configs": {"variable_selector": ["image"], "detail": "low"}},
}
node_data = QuestionClassifierNodeData(**data)
node_data = QuestionClassifierNodeData.model_validate(data)
assert node_data.query_variable_selector == ["id", "name"]
assert node_data.model.provider == "openai"
@ -49,7 +49,7 @@ def test_init_question_classifier_node_data_without_vision_config():
},
}
node_data = QuestionClassifierNodeData(**data)
node_data = QuestionClassifierNodeData.model_validate(data)
assert node_data.query_variable_selector == ["id", "name"]
assert node_data.model.provider == "openai"

View File

@ -87,7 +87,7 @@ def test_overwrite_string_variable():
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.OVER_WRITE.value,
"write_mode": WriteMode.OVER_WRITE,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
}
@ -189,7 +189,7 @@ def test_append_variable_to_array():
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.APPEND.value,
"write_mode": WriteMode.APPEND,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
}
@ -282,7 +282,7 @@ def test_clear_array():
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.CLEAR.value,
"write_mode": WriteMode.CLEAR,
"input_variable_selector": [],
},
}

View File

@ -46,7 +46,7 @@ class TestSystemVariableSerialization:
def test_basic_deserialization(self):
"""Test successful deserialization from JSON structure with all fields correctly mapped."""
# Test with complete data
system_var = SystemVariable(**COMPLETE_VALID_DATA)
system_var = SystemVariable.model_validate(COMPLETE_VALID_DATA)
# Verify all fields are correctly mapped
assert system_var.user_id == COMPLETE_VALID_DATA["user_id"]
@ -59,7 +59,7 @@ class TestSystemVariableSerialization:
assert system_var.files == []
# Test with minimal data (only required fields)
minimal_var = SystemVariable(**VALID_BASE_DATA)
minimal_var = SystemVariable.model_validate(VALID_BASE_DATA)
assert minimal_var.user_id == VALID_BASE_DATA["user_id"]
assert minimal_var.app_id == VALID_BASE_DATA["app_id"]
assert minimal_var.workflow_id == VALID_BASE_DATA["workflow_id"]
@ -75,12 +75,12 @@ class TestSystemVariableSerialization:
# Test workflow_run_id only (preferred alias)
data_run_id = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
system_var1 = SystemVariable(**data_run_id)
system_var1 = SystemVariable.model_validate(data_run_id)
assert system_var1.workflow_execution_id == workflow_id
# Test workflow_execution_id only (direct field name)
data_execution_id = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
system_var2 = SystemVariable(**data_execution_id)
system_var2 = SystemVariable.model_validate(data_execution_id)
assert system_var2.workflow_execution_id == workflow_id
# Test both present - workflow_run_id should take precedence
@ -89,17 +89,17 @@ class TestSystemVariableSerialization:
"workflow_execution_id": "should-be-ignored",
"workflow_run_id": workflow_id,
}
system_var3 = SystemVariable(**data_both)
system_var3 = SystemVariable.model_validate(data_both)
assert system_var3.workflow_execution_id == workflow_id
# Test neither present - should be None
system_var4 = SystemVariable(**VALID_BASE_DATA)
system_var4 = SystemVariable.model_validate(VALID_BASE_DATA)
assert system_var4.workflow_execution_id is None
def test_serialization_round_trip(self):
"""Test that serialize → deserialize produces the same result with alias handling."""
# Create original SystemVariable
original = SystemVariable(**COMPLETE_VALID_DATA)
original = SystemVariable.model_validate(COMPLETE_VALID_DATA)
# Serialize to dict
serialized = original.model_dump(mode="json")
@ -110,7 +110,7 @@ class TestSystemVariableSerialization:
assert serialized["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
# Deserialize back
deserialized = SystemVariable(**serialized)
deserialized = SystemVariable.model_validate(serialized)
# Verify all fields match after round-trip
assert deserialized.user_id == original.user_id
@ -125,7 +125,7 @@ class TestSystemVariableSerialization:
def test_json_round_trip(self):
"""Test JSON serialization/deserialization consistency with proper structure."""
# Create original SystemVariable
original = SystemVariable(**COMPLETE_VALID_DATA)
original = SystemVariable.model_validate(COMPLETE_VALID_DATA)
# Serialize to JSON string
json_str = original.model_dump_json()
@ -137,7 +137,7 @@ class TestSystemVariableSerialization:
assert json_data["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
# Deserialize from JSON data
deserialized = SystemVariable(**json_data)
deserialized = SystemVariable.model_validate(json_data)
# Verify key fields match after JSON round-trip
assert deserialized.workflow_execution_id == original.workflow_execution_id
@ -149,13 +149,13 @@ class TestSystemVariableSerialization:
"""Test deserialization with File objects in the files field - SystemVariable specific logic."""
# Test with empty files list
data_empty = {**VALID_BASE_DATA, "files": []}
system_var_empty = SystemVariable(**data_empty)
system_var_empty = SystemVariable.model_validate(data_empty)
assert system_var_empty.files == []
# Test with single File object
test_file = create_test_file()
data_single = {**VALID_BASE_DATA, "files": [test_file]}
system_var_single = SystemVariable(**data_single)
system_var_single = SystemVariable.model_validate(data_single)
assert len(system_var_single.files) == 1
assert system_var_single.files[0].filename == "test.txt"
assert system_var_single.files[0].tenant_id == "test-tenant-id"
@ -179,14 +179,14 @@ class TestSystemVariableSerialization:
)
data_multiple = {**VALID_BASE_DATA, "files": [file1, file2]}
system_var_multiple = SystemVariable(**data_multiple)
system_var_multiple = SystemVariable.model_validate(data_multiple)
assert len(system_var_multiple.files) == 2
assert system_var_multiple.files[0].filename == "doc1.txt"
assert system_var_multiple.files[1].filename == "image.jpg"
# Verify files field serialization/deserialization
serialized = system_var_multiple.model_dump(mode="json")
deserialized = SystemVariable(**serialized)
deserialized = SystemVariable.model_validate(serialized)
assert len(deserialized.files) == 2
assert deserialized.files[0].filename == "doc1.txt"
assert deserialized.files[1].filename == "image.jpg"
@ -197,7 +197,7 @@ class TestSystemVariableSerialization:
# Create with workflow_run_id (alias)
data_with_alias = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
system_var = SystemVariable(**data_with_alias)
system_var = SystemVariable.model_validate(data_with_alias)
# Serialize and verify alias is used
serialized = system_var.model_dump()
@ -205,7 +205,7 @@ class TestSystemVariableSerialization:
assert "workflow_execution_id" not in serialized
# Deserialize and verify field mapping
deserialized = SystemVariable(**serialized)
deserialized = SystemVariable.model_validate(serialized)
assert deserialized.workflow_execution_id == workflow_id
# Test JSON serialization path
@ -213,7 +213,7 @@ class TestSystemVariableSerialization:
assert json_serialized["workflow_run_id"] == workflow_id
assert "workflow_execution_id" not in json_serialized
json_deserialized = SystemVariable(**json_serialized)
json_deserialized = SystemVariable.model_validate(json_serialized)
assert json_deserialized.workflow_execution_id == workflow_id
def test_model_validator_serialization_logic(self):
@ -222,7 +222,7 @@ class TestSystemVariableSerialization:
# Test direct instantiation with workflow_execution_id (should work)
data1 = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
system_var1 = SystemVariable(**data1)
system_var1 = SystemVariable.model_validate(data1)
assert system_var1.workflow_execution_id == workflow_id
# Test serialization of the above (should use alias)
@ -236,7 +236,7 @@ class TestSystemVariableSerialization:
"workflow_execution_id": "should-be-removed",
"workflow_run_id": workflow_id,
}
system_var2 = SystemVariable(**data2)
system_var2 = SystemVariable.model_validate(data2)
assert system_var2.workflow_execution_id == workflow_id
# Verify serialization consistency

View File

@ -11,7 +11,7 @@ class TestExtractTenantId:
def test_extract_tenant_id_from_account_with_tenant(self):
"""Test extracting tenant_id from Account with current_tenant_id."""
# Create a mock Account object
account = Account()
account = Account(name="test", email="test@example.com")
# Mock the current_tenant_id property
account._current_tenant = type("MockTenant", (), {"id": "account-tenant-123"})()
@ -21,7 +21,7 @@ class TestExtractTenantId:
def test_extract_tenant_id_from_account_without_tenant(self):
"""Test extracting tenant_id from Account without current_tenant_id."""
# Create a mock Account object
account = Account()
account = Account(name="test", email="test@example.com")
account._current_tenant = None
tenant_id = extract_tenant_id(account)

View File

@ -59,12 +59,11 @@ def session():
@pytest.fixture
def mock_user():
"""Create a user instance for testing."""
user = Account()
user = Account(name="test", email="test@example.com")
user.id = "test-user-id"
tenant = Tenant()
tenant = Tenant(name="Test Workspace")
tenant.id = "test-tenant"
tenant.name = "Test Workspace"
user._current_tenant = MagicMock()
user._current_tenant.id = "test-tenant"
@ -299,7 +298,7 @@ def test_to_domain_model(repository):
db_model.predecessor_node_id = "test-predecessor-id"
db_model.node_execution_id = "test-node-execution-id"
db_model.node_id = "test-node-id"
db_model.node_type = NodeType.START.value
db_model.node_type = NodeType.START
db_model.title = "Test Node"
db_model.inputs = json.dumps(inputs_dict)
db_model.process_data = json.dumps(process_data_dict)

View File

@ -118,7 +118,7 @@ class TestMetadataBugCompleteValidation:
# But would crash when trying to create MetadataArgs
with pytest.raises((ValueError, TypeError)):
MetadataArgs(**args)
MetadataArgs.model_validate(args)
def test_7_end_to_end_validation_layers(self):
"""Test all validation layers work together correctly."""
@ -131,7 +131,7 @@ class TestMetadataBugCompleteValidation:
valid_data = {"type": "string", "name": "test_metadata"}
# Should create valid Pydantic object
metadata_args = MetadataArgs(**valid_data)
metadata_args = MetadataArgs.model_validate(valid_data)
assert metadata_args.type == "string"
assert metadata_args.name == "test_metadata"

View File

@ -76,7 +76,7 @@ class TestMetadataNullableBug:
# Step 2: Try to create MetadataArgs with None values
# This should fail at Pydantic validation level
with pytest.raises((ValueError, TypeError)):
metadata_args = MetadataArgs(**args)
metadata_args = MetadataArgs.model_validate(args)
# Step 3: If we bypass Pydantic (simulating the bug scenario)
# Move this outside the request context to avoid Flask-Login issues

View File

@ -107,7 +107,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables):
assert body_data
body_data_json = json.loads(body_data)
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY
body_params = body_data_json["params"]
assert body_params["app_id"] == app_model.id
@ -168,7 +168,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables):
assert body_data
body_data_json = json.loads(body_data)
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY
body_params = body_data_json["params"]
assert body_params["app_id"] == app_model.id

View File

@ -47,7 +47,8 @@ class TestDraftVariableSaver:
def test__should_variable_be_visible(self):
mock_session = MagicMock(spec=Session)
mock_user = Account(id=str(uuid.uuid4()))
mock_user = Account(name="test", email="test@example.com")
mock_user.id = str(uuid.uuid4())
test_app_id = self._get_test_app_id()
saver = DraftVariableSaver(
session=mock_session,