refactor: migrate db.session.query to select in inner_api and web controllers (#33774)

Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Renzo
2026-03-19 19:32:03 +01:00
committed by GitHub
parent f40f6547b4
commit ce370594db
10 changed files with 49 additions and 47 deletions

View File

@ -50,7 +50,7 @@ class TestGetUser:
mock_user.id = "user123"
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.where.return_value.first.return_value = mock_user
mock_session.get.return_value = mock_user
# Act
with app.app_context():
@ -58,7 +58,7 @@ class TestGetUser:
# Assert
assert result == mock_user
mock_session.query.assert_called_once()
mock_session.get.assert_called_once()
@patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.Session")
@ -72,7 +72,8 @@ class TestGetUser:
mock_user.session_id = "anonymous_session"
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.where.return_value.first.return_value = mock_user
# non-anonymous path uses session.get(); anonymous uses session.scalar()
mock_session.get.return_value = mock_user
# Act
with app.app_context():
@ -89,7 +90,7 @@ class TestGetUser:
# Arrange
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.where.return_value.first.return_value = None
mock_session.get.return_value = None
mock_new_user = MagicMock()
mock_enduser_class.return_value = mock_new_user
@ -103,18 +104,20 @@ class TestGetUser:
mock_session.commit.assert_called_once()
mock_session.refresh.assert_called_once()
@patch("controllers.inner_api.plugin.wraps.select")
@patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.Session")
@patch("controllers.inner_api.plugin.wraps.db")
def test_should_use_default_session_id_when_user_id_none(
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
self, mock_db, mock_session_class, mock_enduser_class, mock_select, app: Flask
):
"""Test using default session ID when user_id is None"""
# Arrange
mock_user = MagicMock()
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.where.return_value.first.return_value = mock_user
# When user_id is None, is_anonymous=True, so session.scalar() is used
mock_session.scalar.return_value = mock_user
# Act
with app.app_context():
@ -133,7 +136,7 @@ class TestGetUser:
# Arrange
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.query.side_effect = Exception("Database error")
mock_session.get.side_effect = Exception("Database error")
# Act & Assert
with app.app_context():
@ -161,9 +164,9 @@ class TestGetUserTenant:
# Act
with app.test_request_context(json={"tenant_id": "tenant123", "user_id": "user456"}):
monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False)
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
with patch("controllers.inner_api.plugin.wraps.db.session.get") as mock_get:
with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user:
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_get.return_value = mock_tenant
mock_get_user.return_value = mock_user
result = protected_view()
@ -194,8 +197,8 @@ class TestGetUserTenant:
# Act & Assert
with app.test_request_context(json={"tenant_id": "nonexistent", "user_id": "user456"}):
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
mock_query.return_value.where.return_value.first.return_value = None
with patch("controllers.inner_api.plugin.wraps.db.session.get") as mock_get:
mock_get.return_value = None
with pytest.raises(ValueError, match="tenant not found"):
protected_view()
@ -215,9 +218,9 @@ class TestGetUserTenant:
# Act - use empty string for user_id to trigger default logic
with app.test_request_context(json={"tenant_id": "tenant123", "user_id": ""}):
monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False)
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
with patch("controllers.inner_api.plugin.wraps.db.session.get") as mock_get:
with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user:
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_get.return_value = mock_tenant
mock_get_user.return_value = mock_user
result = protected_view()

View File

@ -249,8 +249,8 @@ class TestEnterpriseInnerApiUserAuth:
headers={"Authorization": f"Bearer {user_id}:{valid_signature}", "X-Inner-Api-Key": inner_api_key}
):
with patch.object(dify_config, "INNER_API", True):
with patch("controllers.inner_api.wraps.db.session.query") as mock_query:
mock_query.return_value.where.return_value.first.return_value = mock_user
with patch("controllers.inner_api.wraps.db.session.get") as mock_get:
mock_get.return_value = mock_user
result = protected_view()
# Assert

View File

@ -91,7 +91,7 @@ class TestEnterpriseWorkspace:
# Arrange
mock_account = MagicMock()
mock_account.email = "owner@example.com"
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
mock_db.session.scalar.return_value = mock_account
now = datetime(2025, 1, 1, 12, 0, 0)
mock_tenant = MagicMock()
@ -122,7 +122,7 @@ class TestEnterpriseWorkspace:
def test_post_returns_404_when_owner_not_found(self, mock_db, api_instance, app: Flask):
"""Test that post() returns 404 when the owner account does not exist"""
# Arrange
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
mock_db.session.scalar.return_value = None
# Act
unwrapped_post = inspect.unwrap(api_instance.post)

View File

@ -49,6 +49,17 @@ class _FakeSession:
assert self._model_name is not None
return self._mapping.get(self._model_name)
def get(self, model, ident):
return self._mapping.get(model.__name__)
def scalar(self, stmt):
# Extract the model name from the select statement's column_descriptions
try:
name = stmt.column_descriptions[0]["entity"].__name__
except (AttributeError, IndexError, KeyError):
return None
return self._mapping.get(name)
class _FakeDB:
"""Minimal db stub exposing engine and session."""

View File

@ -50,7 +50,7 @@ class TestAppSiteApi:
app.config["RESTX_MASK_HEADER"] = "X-Fields"
mock_features.return_value = SimpleNamespace(can_replace_logo=False)
site_obj = _site()
mock_db.session.query.return_value.where.return_value.first.return_value = site_obj
mock_db.session.scalar.return_value = site_obj
tenant = _tenant()
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
end_user = SimpleNamespace(id="eu-1")
@ -66,9 +66,9 @@ class TestAppSiteApi:
@patch("controllers.web.site.db")
def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
app.config["RESTX_MASK_HEADER"] = "X-Fields"
mock_db.session.query.return_value.where.return_value.first.return_value = None
mock_db.session.scalar.return_value = None
tenant = _tenant()
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
end_user = SimpleNamespace(id="eu-1")
with app.test_request_context("/site"):
@ -80,7 +80,7 @@ class TestAppSiteApi:
app.config["RESTX_MASK_HEADER"] = "X-Fields"
from models.account import TenantStatus
mock_db.session.query.return_value.where.return_value.first.return_value = _site()
mock_db.session.scalar.return_value = _site()
tenant = SimpleNamespace(
id="tenant-1",
status=TenantStatus.ARCHIVE,