feat: replace db.session with db_session_with_containers (#32942)

This commit is contained in:
Renzo
2026-03-03 21:50:41 -08:00
committed by GitHub
parent 2f4c740d46
commit ad000c42b7
43 changed files with 3078 additions and 2669 deletions

View File

@ -3,6 +3,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from pydantic import TypeAdapter, ValidationError
from sqlalchemy.orm import Session
from core.tools.entities.tool_entities import ApiProviderSchemaType
from models import Account, Tenant
@ -34,7 +35,7 @@ class TestApiToolManageService:
"provider_controller": mock_provider_controller,
}
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test account and tenant for testing.
@ -55,18 +56,16 @@ class TestApiToolManageService:
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant for the account
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
from models.account import TenantAccountJoin, TenantAccountRole
@ -77,8 +76,8 @@ class TestApiToolManageService:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Set current tenant for account
account.current_tenant = tenant
@ -118,7 +117,7 @@ class TestApiToolManageService:
"""
def test_parser_api_schema_success(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful parsing of API schema.
@ -163,7 +162,7 @@ class TestApiToolManageService:
assert api_key_value_field["default"] == ""
def test_parser_api_schema_invalid_schema(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test parsing of invalid API schema.
@ -183,7 +182,7 @@ class TestApiToolManageService:
assert "invalid schema" in str(exc_info.value)
def test_parser_api_schema_malformed_json(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test parsing of malformed JSON schema.
@ -203,7 +202,7 @@ class TestApiToolManageService:
assert "invalid schema" in str(exc_info.value)
def test_convert_schema_to_tool_bundles_success(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful conversion of schema to tool bundles.
@ -233,7 +232,7 @@ class TestApiToolManageService:
assert tool_bundle.operation_id == "testOperation"
def test_convert_schema_to_tool_bundles_with_extra_info(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful conversion of schema to tool bundles with extra info.
@ -259,7 +258,7 @@ class TestApiToolManageService:
assert isinstance(schema_type, str)
def test_convert_schema_to_tool_bundles_invalid_schema(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test conversion of invalid schema to tool bundles.
@ -279,7 +278,7 @@ class TestApiToolManageService:
assert "invalid schema" in str(exc_info.value)
def test_create_api_tool_provider_success(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful creation of API tool provider.
@ -324,10 +323,9 @@ class TestApiToolManageService:
assert result == {"result": "success"}
# Verify database state
from extensions.ext_database import db
provider = (
db.session.query(ApiToolProvider)
db_session_with_containers.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
.first()
)
@ -347,7 +345,7 @@ class TestApiToolManageService:
mock_external_service_dependencies["provider_controller"].load_bundled_tools.assert_called_once()
def test_create_api_tool_provider_duplicate_name(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test creation of API tool provider with duplicate name.
@ -404,7 +402,7 @@ class TestApiToolManageService:
assert f"provider {provider_name} already exists" in str(exc_info.value)
def test_create_api_tool_provider_invalid_schema_type(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test creation of API tool provider with invalid schema type.
@ -436,7 +434,7 @@ class TestApiToolManageService:
assert "validation error" in str(exc_info.value)
def test_create_api_tool_provider_missing_auth_type(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test creation of API tool provider with missing auth type.
@ -479,7 +477,7 @@ class TestApiToolManageService:
assert "auth_type is required" in str(exc_info.value)
def test_create_api_tool_provider_with_api_key_auth(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful creation of API tool provider with API key authentication.
@ -522,10 +520,9 @@ class TestApiToolManageService:
assert result == {"result": "success"}
# Verify database state
from extensions.ext_database import db
provider = (
db.session.query(ApiToolProvider)
db_session_with_containers.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
.first()
)

View File

@ -2,6 +2,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.tools.entities.tool_entities import ToolProviderType
from models import Account, Tenant
@ -41,7 +42,7 @@ class TestMCPToolManageService:
"tool_transform_service": mock_tool_transform_service,
}
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test account and tenant for testing.
@ -62,18 +63,16 @@ class TestMCPToolManageService:
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant for the account
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
from models.account import TenantAccountJoin, TenantAccountRole
@ -84,8 +83,8 @@ class TestMCPToolManageService:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Set current tenant for account
account.current_tenant = tenant
@ -93,7 +92,7 @@ class TestMCPToolManageService:
return account, tenant
def _create_test_mcp_provider(
self, db_session_with_containers, mock_external_service_dependencies, tenant_id, user_id
self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id, user_id
):
"""
Helper method to create a test MCP tool provider for testing.
@ -124,15 +123,13 @@ class TestMCPToolManageService:
sse_read_timeout=300.0,
)
from extensions.ext_database import db
db.session.add(mcp_provider)
db.session.commit()
db_session_with_containers.add(mcp_provider)
db_session_with_containers.commit()
return mcp_provider
def test_get_mcp_provider_by_provider_id_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of MCP provider by provider ID.
@ -153,9 +150,8 @@ class TestMCPToolManageService:
)
# Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
result = service.get_provider(provider_id=mcp_provider.id, tenant_id=tenant.id)
# Assert: Verify the expected outcomes
@ -166,12 +162,12 @@ class TestMCPToolManageService:
assert result.user_id == account.id
# Verify database state
db.session.refresh(result)
db_session_with_containers.refresh(result)
assert result.id is not None
assert result.server_identifier == mcp_provider.server_identifier
def test_get_mcp_provider_by_provider_id_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when MCP provider is not found by provider ID.
@ -190,14 +186,13 @@ class TestMCPToolManageService:
non_existent_id = str(fake.uuid4())
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="MCP tool not found"):
service.get_provider(provider_id=non_existent_id, tenant_id=tenant.id)
def test_get_mcp_provider_by_provider_id_tenant_isolation(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tenant isolation when retrieving MCP provider by provider ID.
@ -223,14 +218,13 @@ class TestMCPToolManageService:
)
# Act & Assert: Verify tenant isolation
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="MCP tool not found"):
service.get_provider(provider_id=mcp_provider1.id, tenant_id=tenant2.id)
def test_get_mcp_provider_by_server_identifier_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of MCP provider by server identifier.
@ -251,9 +245,8 @@ class TestMCPToolManageService:
)
# Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
result = service.get_provider(server_identifier=mcp_provider.server_identifier, tenant_id=tenant.id)
# Assert: Verify the expected outcomes
@ -264,12 +257,12 @@ class TestMCPToolManageService:
assert result.user_id == account.id
# Verify database state
db.session.refresh(result)
db_session_with_containers.refresh(result)
assert result.id is not None
assert result.name == mcp_provider.name
def test_get_mcp_provider_by_server_identifier_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when MCP provider is not found by server identifier.
@ -288,14 +281,13 @@ class TestMCPToolManageService:
non_existent_identifier = str(fake.uuid4())
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="MCP tool not found"):
service.get_provider(server_identifier=non_existent_identifier, tenant_id=tenant.id)
def test_get_mcp_provider_by_server_identifier_tenant_isolation(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tenant isolation when retrieving MCP provider by server identifier.
@ -321,13 +313,12 @@ class TestMCPToolManageService:
)
# Act & Assert: Verify tenant isolation
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="MCP tool not found"):
service.get_provider(server_identifier=mcp_provider1.server_identifier, tenant_id=tenant2.id)
def test_create_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_create_mcp_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful creation of MCP provider.
@ -365,9 +356,8 @@ class TestMCPToolManageService:
# Act: Execute the method under test
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
result = service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider",
@ -389,10 +379,9 @@ class TestMCPToolManageService:
assert result.type == ToolProviderType.MCP
# Verify database state
from extensions.ext_database import db
created_provider = (
db.session.query(MCPToolProvider)
db_session_with_containers.query(MCPToolProvider)
.filter(MCPToolProvider.tenant_id == tenant.id, MCPToolProvider.name == "Test MCP Provider")
.first()
)
@ -410,7 +399,9 @@ class TestMCPToolManageService:
)
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.assert_called_once()
def test_create_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):
def test_create_mcp_provider_duplicate_name(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when creating MCP provider with duplicate name.
@ -427,9 +418,8 @@ class TestMCPToolManageService:
# Create first provider
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider",
@ -463,7 +453,7 @@ class TestMCPToolManageService:
)
def test_create_mcp_provider_duplicate_server_url(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when creating MCP provider with duplicate server URL.
@ -481,9 +471,8 @@ class TestMCPToolManageService:
# Create first provider
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider 1",
@ -517,7 +506,7 @@ class TestMCPToolManageService:
)
def test_create_mcp_provider_duplicate_server_identifier(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when creating MCP provider with duplicate server identifier.
@ -535,9 +524,8 @@ class TestMCPToolManageService:
# Create first provider
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider 1",
@ -570,7 +558,7 @@ class TestMCPToolManageService:
),
)
def test_retrieve_mcp_tools_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_retrieve_mcp_tools_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful retrieval of MCP tools for a tenant.
@ -602,9 +590,7 @@ class TestMCPToolManageService:
)
provider3.name = "Gamma Provider"
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Setup mock for transformation service
from core.tools.entities.api_entities import ToolProviderApiEntity
@ -647,9 +633,8 @@ class TestMCPToolManageService:
]
# Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
result = service.list_providers(tenant_id=tenant.id, for_list=True)
# Assert: Verify the expected outcomes
@ -666,7 +651,9 @@ class TestMCPToolManageService:
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.call_count == 3
)
def test_retrieve_mcp_tools_empty_list(self, db_session_with_containers, mock_external_service_dependencies):
def test_retrieve_mcp_tools_empty_list(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test retrieval of MCP tools when tenant has no providers.
@ -684,9 +671,8 @@ class TestMCPToolManageService:
# No MCP providers created for this tenant
# Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
result = service.list_providers(tenant_id=tenant.id, for_list=False)
# Assert: Verify the expected outcomes
@ -697,7 +683,9 @@ class TestMCPToolManageService:
# Verify no transformation service calls for empty list
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.assert_not_called()
def test_retrieve_mcp_tools_tenant_isolation(self, db_session_with_containers, mock_external_service_dependencies):
def test_retrieve_mcp_tools_tenant_isolation(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tenant isolation when retrieving MCP tools.
@ -756,9 +744,8 @@ class TestMCPToolManageService:
]
# Act: Execute the method under test for both tenants
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
result1 = service.list_providers(tenant_id=tenant1.id, for_list=True)
result2 = service.list_providers(tenant_id=tenant2.id, for_list=True)
@ -769,7 +756,7 @@ class TestMCPToolManageService:
assert result2[0].id == provider2.id
def test_list_mcp_tool_from_remote_server_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful listing of MCP tools from remote server.
@ -797,9 +784,7 @@ class TestMCPToolManageService:
mcp_provider.authed = True # Provider must be authenticated to list tools
mcp_provider.tools = "[]"
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Mock the decryption process at the rsa level to avoid key file issues
with patch("libs.rsa.decrypt") as mock_decrypt:
@ -821,9 +806,8 @@ class TestMCPToolManageService:
mock_client_instance.list_tools.return_value = mock_tools
# Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
result = service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Assert: Verify the expected outcomes
@ -834,7 +818,7 @@ class TestMCPToolManageService:
# Note: server_url is mocked, so we skip that assertion to avoid encryption issues
# Verify database state was updated
db.session.refresh(mcp_provider)
db_session_with_containers.refresh(mcp_provider)
assert mcp_provider.authed is True
assert mcp_provider.tools != "[]"
assert mcp_provider.updated_at is not None
@ -844,7 +828,7 @@ class TestMCPToolManageService:
mock_mcp_client.assert_called_once()
def test_list_mcp_tool_from_remote_server_auth_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when MCP server requires authentication.
@ -871,9 +855,7 @@ class TestMCPToolManageService:
mcp_provider.authed = False
mcp_provider.tools = "[]"
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Mock the decryption process at the rsa level to avoid key file issues
with patch("libs.rsa.decrypt") as mock_decrypt:
@ -887,19 +869,18 @@ class TestMCPToolManageService:
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="Please auth the tool first"):
service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Verify database state was not changed
db.session.refresh(mcp_provider)
db_session_with_containers.refresh(mcp_provider)
assert mcp_provider.authed is False
assert mcp_provider.tools == "[]"
def test_list_mcp_tool_from_remote_server_connection_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when MCP server connection fails.
@ -926,9 +907,7 @@ class TestMCPToolManageService:
mcp_provider.authed = True # Provider must be authenticated to test connection errors
mcp_provider.tools = "[]"
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Mock the decryption process at the rsa level to avoid key file issues
with patch("libs.rsa.decrypt") as mock_decrypt:
@ -942,18 +921,17 @@ class TestMCPToolManageService:
mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="Failed to connect to MCP server: Connection failed"):
service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Verify database state was not changed
db.session.refresh(mcp_provider)
db_session_with_containers.refresh(mcp_provider)
assert mcp_provider.authed is True # Provider remains authenticated
assert mcp_provider.tools == "[]"
def test_delete_mcp_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_delete_mcp_tool_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful deletion of MCP tool.
@ -974,20 +952,19 @@ class TestMCPToolManageService:
)
# Verify provider exists
from extensions.ext_database import db
assert db.session.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() is not None
assert db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() is not None
# Act: Execute the method under test
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
service.delete_provider(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Assert: Verify the expected outcomes
# Provider should be deleted from database
deleted_provider = db.session.query(MCPToolProvider).filter_by(id=mcp_provider.id).first()
deleted_provider = db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider.id).first()
assert deleted_provider is None
def test_delete_mcp_tool_not_found(self, db_session_with_containers, mock_external_service_dependencies):
def test_delete_mcp_tool_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test error handling when deleting non-existent MCP tool.
@ -1005,13 +982,14 @@ class TestMCPToolManageService:
non_existent_id = str(fake.uuid4())
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="MCP tool not found"):
service.delete_provider(tenant_id=tenant.id, provider_id=non_existent_id)
def test_delete_mcp_tool_tenant_isolation(self, db_session_with_containers, mock_external_service_dependencies):
def test_delete_mcp_tool_tenant_isolation(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tenant isolation when deleting MCP tool.
@ -1036,18 +1014,16 @@ class TestMCPToolManageService:
)
# Act & Assert: Verify tenant isolation
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="MCP tool not found"):
service.delete_provider(tenant_id=tenant2.id, provider_id=mcp_provider1.id)
# Verify provider still exists in tenant1
from extensions.ext_database import db
assert db.session.query(MCPToolProvider).filter_by(id=mcp_provider1.id).first() is not None
assert db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider1.id).first() is not None
def test_update_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_mcp_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful update of MCP provider.
@ -1070,14 +1046,12 @@ class TestMCPToolManageService:
original_name = mcp_provider.name
original_icon = mcp_provider.icon
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Act: Execute the method under test
from core.entities.mcp_provider import MCPConfiguration
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
service.update_provider(
tenant_id=tenant.id,
provider_id=mcp_provider.id,
@ -1094,7 +1068,7 @@ class TestMCPToolManageService:
)
# Assert: Verify the expected outcomes
db.session.refresh(mcp_provider)
db_session_with_containers.refresh(mcp_provider)
assert mcp_provider.name == "Updated MCP Provider"
assert mcp_provider.server_identifier == "updated_identifier_123"
assert mcp_provider.timeout == 45.0
@ -1108,7 +1082,9 @@ class TestMCPToolManageService:
assert icon_data["content"] == "🚀"
assert icon_data["background"] == "#4ECDC4"
def test_update_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_mcp_provider_duplicate_name(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when updating MCP provider with duplicate name.
@ -1134,15 +1110,12 @@ class TestMCPToolManageService:
)
provider2.name = "Second Provider"
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Act & Assert: Verify proper error handling for duplicate name
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="MCP tool First Provider already exists"):
service.update_provider(
tenant_id=tenant.id,
@ -1160,7 +1133,7 @@ class TestMCPToolManageService:
)
def test_update_mcp_provider_credentials_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful update of MCP provider credentials.
@ -1185,9 +1158,7 @@ class TestMCPToolManageService:
mcp_provider.authed = False
mcp_provider.tools = "[]"
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Mock the provider controller and encryption
with (
@ -1202,9 +1173,8 @@ class TestMCPToolManageService:
mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"}
# Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
service.update_provider_credentials(
provider_id=mcp_provider.id,
tenant_id=tenant.id,
@ -1213,7 +1183,7 @@ class TestMCPToolManageService:
)
# Assert: Verify the expected outcomes
db.session.refresh(mcp_provider)
db_session_with_containers.refresh(mcp_provider)
assert mcp_provider.authed is True
assert mcp_provider.updated_at is not None
@ -1225,7 +1195,7 @@ class TestMCPToolManageService:
assert "new_key" in credentials
def test_update_mcp_provider_credentials_not_authed(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test update of MCP provider credentials when not authenticated.
@ -1249,9 +1219,7 @@ class TestMCPToolManageService:
mcp_provider.authed = True
mcp_provider.tools = '[{"name": "test_tool"}]'
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Mock the provider controller and encryption
with (
@ -1266,9 +1234,8 @@ class TestMCPToolManageService:
mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"}
# Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
service.update_provider_credentials(
provider_id=mcp_provider.id,
tenant_id=tenant.id,
@ -1277,12 +1244,14 @@ class TestMCPToolManageService:
)
# Assert: Verify the expected outcomes
db.session.refresh(mcp_provider)
db_session_with_containers.refresh(mcp_provider)
assert mcp_provider.authed is False
assert mcp_provider.tools == "[]"
assert mcp_provider.updated_at is not None
def test_re_connect_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_re_connect_mcp_provider_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful reconnection to MCP provider.
@ -1343,7 +1312,9 @@ class TestMCPToolManageService:
sse_read_timeout=mcp_provider.sse_read_timeout,
)
def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies):
def test_re_connect_mcp_provider_auth_error(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test reconnection to MCP provider when authentication fails.
@ -1385,7 +1356,7 @@ class TestMCPToolManageService:
assert result.encrypted_credentials == "{}"
def test_re_connect_mcp_provider_connection_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test reconnection to MCP provider when connection fails.

View File

@ -2,6 +2,7 @@ from unittest.mock import Mock, patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
@ -27,7 +28,7 @@ class TestToolTransformService:
}
def _create_test_tool_provider(
self, db_session_with_containers, mock_external_service_dependencies, provider_type="api"
self, db_session_with_containers: Session, mock_external_service_dependencies, provider_type="api"
):
"""
Helper method to create a test tool provider for testing.
@ -89,14 +90,12 @@ class TestToolTransformService:
else:
raise ValueError(f"Unknown provider type: {provider_type}")
from extensions.ext_database import db
db.session.add(provider)
db.session.commit()
db_session_with_containers.add(provider)
db_session_with_containers.commit()
return provider
def test_get_plugin_icon_url_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_plugin_icon_url_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful plugin icon URL generation.
@ -126,7 +125,7 @@ class TestToolTransformService:
assert result == expected_url
def test_get_plugin_icon_url_with_empty_console_url(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test plugin icon URL generation when CONSOLE_API_URL is empty.
@ -156,7 +155,7 @@ class TestToolTransformService:
assert result == expected_url
def test_get_tool_provider_icon_url_builtin_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful tool provider icon URL generation for builtin providers.
@ -194,7 +193,7 @@ class TestToolTransformService:
assert result == expected_encoded
def test_get_tool_provider_icon_url_api_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful tool provider icon URL generation for API providers.
@ -220,7 +219,7 @@ class TestToolTransformService:
assert result["content"] == "🔧"
def test_get_tool_provider_icon_url_api_invalid_json(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tool provider icon URL generation for API providers with invalid JSON.
@ -246,7 +245,7 @@ class TestToolTransformService:
assert result["content"] == "😁" or result["content"] == "\ud83d\ude01"
def test_get_tool_provider_icon_url_workflow_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful tool provider icon URL generation for workflow providers.
@ -271,7 +270,7 @@ class TestToolTransformService:
assert result["content"] == "🔧"
def test_get_tool_provider_icon_url_mcp_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful tool provider icon URL generation for MCP providers.
@ -296,7 +295,7 @@ class TestToolTransformService:
assert result["content"] == "🔧"
def test_get_tool_provider_icon_url_unknown_type(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tool provider icon URL generation for unknown provider types.
@ -317,7 +316,9 @@ class TestToolTransformService:
# Assert: Verify the expected outcomes
assert result == ""
def test_repack_provider_dict_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_repack_provider_dict_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful provider repacking with dictionary input.
@ -341,7 +342,9 @@ class TestToolTransformService:
# Note: provider name may contain spaces that get URL encoded
assert provider["name"].replace(" ", "%20") in provider["icon"] or provider["name"] in provider["icon"]
def test_repack_provider_entity_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_repack_provider_entity_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful provider repacking with ToolProviderApiEntity input.
@ -389,7 +392,7 @@ class TestToolTransformService:
assert "test_icon_dark.png" in provider.icon_dark
def test_repack_provider_entity_no_plugin_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful provider repacking with ToolProviderApiEntity input without plugin_id.
@ -435,7 +438,9 @@ class TestToolTransformService:
assert provider.icon_dark["background"] == "#252525"
assert provider.icon_dark["content"] == "🔧"
def test_repack_provider_entity_no_dark_icon(self, db_session_with_containers, mock_external_service_dependencies):
def test_repack_provider_entity_no_dark_icon(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test provider repacking with ToolProviderApiEntity input without dark icon.
@ -477,7 +482,7 @@ class TestToolTransformService:
assert provider.icon_dark == ""
def test_builtin_provider_to_user_provider_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful conversion of builtin provider to user provider.
@ -545,7 +550,7 @@ class TestToolTransformService:
assert result.original_credentials == {"api_key": "decrypted_key"}
def test_builtin_provider_to_user_provider_plugin_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful conversion of builtin provider to user provider with plugin.
@ -589,7 +594,7 @@ class TestToolTransformService:
assert result.allow_delete is False
def test_builtin_provider_to_user_provider_no_credentials(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test conversion of builtin provider to user provider without credentials.
@ -630,7 +635,9 @@ class TestToolTransformService:
assert result.allow_delete is False
assert result.masked_credentials == {"api_key": ""}
def test_api_provider_to_controller_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_api_provider_to_controller_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful conversion of API provider to controller.
@ -655,10 +662,8 @@ class TestToolTransformService:
tools_str="[]",
)
from extensions.ext_database import db
db.session.add(provider)
db.session.commit()
db_session_with_containers.add(provider)
db_session_with_containers.commit()
# Act: Execute the method under test
result = ToolTransformService.api_provider_to_controller(provider)
@ -669,7 +674,7 @@ class TestToolTransformService:
# Additional assertions would depend on the actual controller implementation
def test_api_provider_to_controller_api_key_query(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test conversion of API provider to controller with api_key_query auth type.
@ -693,10 +698,8 @@ class TestToolTransformService:
tools_str="[]",
)
from extensions.ext_database import db
db.session.add(provider)
db.session.commit()
db_session_with_containers.add(provider)
db_session_with_containers.commit()
# Act: Execute the method under test
result = ToolTransformService.api_provider_to_controller(provider)
@ -706,7 +709,7 @@ class TestToolTransformService:
assert hasattr(result, "from_db")
def test_api_provider_to_controller_backward_compatibility(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test conversion of API provider to controller with backward compatibility auth types.
@ -731,10 +734,8 @@ class TestToolTransformService:
tools_str="[]",
)
from extensions.ext_database import db
db.session.add(provider)
db.session.commit()
db_session_with_containers.add(provider)
db_session_with_containers.commit()
# Act: Execute the method under test
result = ToolTransformService.api_provider_to_controller(provider)
@ -744,7 +745,7 @@ class TestToolTransformService:
assert hasattr(result, "from_db")
def test_workflow_provider_to_controller_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful conversion of workflow provider to controller.
@ -769,10 +770,8 @@ class TestToolTransformService:
parameter_configuration="[]",
)
from extensions.ext_database import db
db.session.add(provider)
db.session.commit()
db_session_with_containers.add(provider)
db_session_with_containers.commit()
# Mock the WorkflowToolProviderController.from_db method to avoid app dependency
with patch("services.tools.tools_transform_service.WorkflowToolProviderController.from_db") as mock_from_db:

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from pydantic import ValidationError
from sqlalchemy.orm import Session
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
@ -63,7 +64,7 @@ class TestWorkflowToolManageService:
"tool_transform_service": mock_tool_transform_service,
}
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test app and account for testing.
@ -119,14 +120,12 @@ class TestWorkflowToolManageService:
conversation_variables=[],
)
from extensions.ext_database import db
db.session.add(workflow)
db.session.commit()
db_session_with_containers.add(workflow)
db_session_with_containers.commit()
# Update app to reference the workflow
app.workflow_id = workflow.id
db.session.commit()
db_session_with_containers.commit()
return app, account, workflow
@ -153,7 +152,9 @@ class TestWorkflowToolManageService:
),
]
def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_create_workflow_tool_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful workflow tool creation with valid parameters.
@ -198,11 +199,10 @@ class TestWorkflowToolManageService:
assert result == {"result": "success"}
# Verify database state
from extensions.ext_database import db
# Check if workflow tool provider was created
created_tool_provider = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
WorkflowToolProvider.app_id == app.id,
@ -230,7 +230,7 @@ class TestWorkflowToolManageService:
].workflow_provider_to_controller.assert_called_once()
def test_create_workflow_tool_duplicate_name_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow tool creation fails when name already exists.
@ -280,10 +280,9 @@ class TestWorkflowToolManageService:
assert f"Tool with name {first_tool_name} or app_id {app.id} already exists" in str(exc_info.value)
# Verify only one tool was created
from extensions.ext_database import db
tool_count = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
)
@ -293,7 +292,7 @@ class TestWorkflowToolManageService:
assert tool_count == 1
def test_create_workflow_tool_invalid_app_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow tool creation fails when app does not exist.
@ -331,10 +330,9 @@ class TestWorkflowToolManageService:
assert f"App {non_existent_app_id} not found" in str(exc_info.value)
# Verify no workflow tool was created
from extensions.ext_database import db
tool_count = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
)
@ -344,7 +342,7 @@ class TestWorkflowToolManageService:
assert tool_count == 0
def test_create_workflow_tool_invalid_parameters_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow tool creation fails when parameters are invalid.
@ -387,10 +385,9 @@ class TestWorkflowToolManageService:
assert "validation error" in str(exc_info.value).lower()
# Verify no workflow tool was created
from extensions.ext_database import db
tool_count = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
)
@ -400,7 +397,7 @@ class TestWorkflowToolManageService:
assert tool_count == 0
def test_create_workflow_tool_duplicate_app_id_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow tool creation fails when app_id already exists.
@ -450,10 +447,9 @@ class TestWorkflowToolManageService:
assert f"Tool with name {second_tool_name} or app_id {app.id} already exists" in str(exc_info.value)
# Verify only one tool was created
from extensions.ext_database import db
tool_count = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
)
@ -463,7 +459,7 @@ class TestWorkflowToolManageService:
assert tool_count == 1
def test_create_workflow_tool_workflow_not_found_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow tool creation fails when app has no workflow.
@ -481,10 +477,9 @@ class TestWorkflowToolManageService:
)
# Remove workflow reference from app
from extensions.ext_database import db
app.workflow_id = None
db.session.commit()
db_session_with_containers.commit()
# Attempt to create workflow tool for app without workflow
tool_parameters = self._create_test_workflow_tool_parameters()
@ -505,7 +500,7 @@ class TestWorkflowToolManageService:
# Verify no workflow tool was created
tool_count = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
)
@ -515,7 +510,7 @@ class TestWorkflowToolManageService:
assert tool_count == 0
def test_create_workflow_tool_human_input_node_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow tool creation fails when workflow contains human input nodes.
@ -558,10 +553,8 @@ class TestWorkflowToolManageService:
assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
from extensions.ext_database import db
tool_count = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
)
@ -570,7 +563,9 @@ class TestWorkflowToolManageService:
assert tool_count == 0
def test_update_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_workflow_tool_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful workflow tool update with valid parameters.
@ -603,10 +598,9 @@ class TestWorkflowToolManageService:
)
# Get the created tool
from extensions.ext_database import db
created_tool = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
WorkflowToolProvider.app_id == app.id,
@ -641,7 +635,7 @@ class TestWorkflowToolManageService:
assert result == {"result": "success"}
# Verify database state was updated
db.session.refresh(created_tool)
db_session_with_containers.refresh(created_tool)
assert created_tool is not None
assert created_tool.name == updated_tool_name
assert created_tool.label == updated_tool_label
@ -658,7 +652,7 @@ class TestWorkflowToolManageService:
mock_external_service_dependencies["tool_transform_service"].workflow_provider_to_controller.assert_called()
def test_update_workflow_tool_human_input_node_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow tool update fails when workflow contains human input nodes.
@ -689,10 +683,8 @@ class TestWorkflowToolManageService:
parameters=initial_tool_parameters,
)
from extensions.ext_database import db
created_tool = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
WorkflowToolProvider.app_id == app.id,
@ -712,7 +704,7 @@ class TestWorkflowToolManageService:
]
}
)
db.session.commit()
db_session_with_containers.commit()
with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info:
WorkflowToolManageService.update_workflow_tool(
@ -728,10 +720,12 @@ class TestWorkflowToolManageService:
assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
db.session.refresh(created_tool)
db_session_with_containers.refresh(created_tool)
assert created_tool.name == original_name
def test_update_workflow_tool_not_found_error(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_workflow_tool_not_found_error(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow tool update fails when tool does not exist.
@ -768,10 +762,9 @@ class TestWorkflowToolManageService:
assert f"Tool {non_existent_tool_id} not found" in str(exc_info.value)
# Verify no workflow tool was created
from extensions.ext_database import db
tool_count = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
)
@ -781,7 +774,7 @@ class TestWorkflowToolManageService:
assert tool_count == 0
def test_update_workflow_tool_same_name_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow tool update succeeds when keeping the same name.
@ -813,10 +806,9 @@ class TestWorkflowToolManageService:
)
# Get the created tool
from extensions.ext_database import db
created_tool = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
WorkflowToolProvider.app_id == app.id,
@ -840,12 +832,12 @@ class TestWorkflowToolManageService:
assert result == {"result": "success"}
# Verify tool still exists with the same name
db.session.refresh(created_tool)
db_session_with_containers.refresh(created_tool)
assert created_tool.name == first_tool_name
assert created_tool.updated_at is not None
def test_create_workflow_tool_with_file_parameter_default(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow tool creation with FILE parameter having a file object as default.
@ -916,7 +908,7 @@ class TestWorkflowToolManageService:
assert result == {"result": "success"}
def test_create_workflow_tool_with_files_parameter_default(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow tool creation with FILES (Array[File]) parameter having file objects as default.
@ -991,7 +983,7 @@ class TestWorkflowToolManageService:
assert result == {"result": "success"}
def test_create_workflow_tool_db_commit_before_validation(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test that database commit happens before validation, causing DB pollution on validation failure.
@ -1035,10 +1027,9 @@ class TestWorkflowToolManageService:
# Verify the tool was NOT created in database
# This is the expected behavior (no pollution)
from extensions.ext_database import db
tool_count = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
WorkflowToolProvider.name == tool_name,