From ce370594db45459b4963c76bf78ca5519ce599dc Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Thu, 19 Mar 2026 19:32:03 +0100 Subject: [PATCH] refactor: migrate db.session.query to select in inner_api and web controllers (#33774) Co-authored-by: Asuka Minato Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/inner_api/plugin/wraps.py | 27 ++++------------- .../inner_api/workspace/workspace.py | 3 +- api/controllers/inner_api/wraps.py | 2 +- api/controllers/web/human_input_form.py | 5 ++-- api/controllers/web/site.py | 3 +- .../inner_api/plugin/test_plugin_wraps.py | 29 ++++++++++--------- .../controllers/inner_api/test_auth_wraps.py | 4 +-- .../inner_api/workspace/test_workspace.py | 4 +-- .../controllers/web/test_human_input_form.py | 11 +++++++ .../unit_tests/controllers/web/test_site.py | 8 ++--- 10 files changed, 49 insertions(+), 47 deletions(-) diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index 766d95b3dd..d6e3ebfbcd 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -5,6 +5,7 @@ from typing import ParamSpec, TypeVar from flask import current_app, request from flask_login import user_logged_in from pydantic import BaseModel +from sqlalchemy import select from sqlalchemy.orm import Session from extensions.ext_database import db @@ -36,23 +37,16 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: user_model = None if is_anonymous: - user_model = ( - session.query(EndUser) + user_model = session.scalar( + select(EndUser) .where( EndUser.session_id == user_id, EndUser.tenant_id == tenant_id, ) - .first() + .limit(1) ) else: - user_model = ( - session.query(EndUser) - .where( - EndUser.id == user_id, - EndUser.tenant_id == tenant_id, - ) - .first() - ) + user_model = session.get(EndUser, user_id) if not user_model: user_model = EndUser( @@ -85,16 +79,7 @@ def get_user_tenant(view_func: Callable[P, R]): if not user_id: user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID - try: - tenant_model = ( - db.session.query(Tenant) - .where( - Tenant.id == tenant_id, - ) - .first() - ) - except Exception: - raise ValueError("tenant not found") + tenant_model = db.session.get(Tenant, tenant_id) if not tenant_model: raise ValueError("tenant not found") diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index a5746abafa..ef0a46db63 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -2,6 +2,7 @@ import json from flask_restx import Resource from pydantic import BaseModel +from sqlalchemy import select from controllers.common.schema import register_schema_models from controllers.console.wraps import setup_required @@ -42,7 +43,7 @@ class EnterpriseWorkspace(Resource): def post(self): args = WorkspaceCreatePayload.model_validate(inner_api_ns.payload or {}) - account = db.session.query(Account).filter_by(email=args.owner_email).first() + account = db.session.scalar(select(Account).where(Account.email == args.owner_email).limit(1)) if account is None: return {"message": "owner account not found."}, 404 diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index 4bdcc6832a..7c60b316e8 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -75,7 +75,7 @@ def enterprise_inner_api_user_auth(view: Callable[P, R]): if signature_base64 != token: return view(*args, **kwargs) - kwargs["user"] = db.session.query(EndUser).where(EndUser.id == user_id).first() + kwargs["user"] = db.session.get(EndUser, user_id) return view(*args, **kwargs) diff --git a/api/controllers/web/human_input_form.py b/api/controllers/web/human_input_form.py index 4e69e56025..36728a47d1 100644 --- a/api/controllers/web/human_input_form.py +++ b/api/controllers/web/human_input_form.py @@ -8,6 +8,7 @@ from datetime import datetime from flask import Response, request from flask_restx import Resource, reqparse +from sqlalchemy import select from werkzeug.exceptions import Forbidden from configs import dify_config @@ -147,11 +148,11 @@ class HumanInputFormApi(Resource): def _get_app_site_from_form(form: Form) -> tuple[App, Site]: """Resolve App/Site for the form's app and validate tenant status.""" - app_model = db.session.query(App).where(App.id == form.app_id).first() + app_model = db.session.get(App, form.app_id) if app_model is None or app_model.tenant_id != form.tenant_id: raise NotFoundError("Form not found") - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if site is None: raise Forbidden() diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index f957229ece..1a0c6d4252 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,6 +1,7 @@ from typing import cast from flask_restx import fields, marshal, marshal_with +from sqlalchemy import select from werkzeug.exceptions import Forbidden from configs import dify_config @@ -72,7 +73,7 @@ class AppSiteApi(WebApiResource): def get(self, app_model, end_user): """Retrieve app site info.""" # get site - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise Forbidden() diff --git a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py index 6de07a23e5..eac57fe4b7 100644 --- a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py +++ b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py @@ -50,7 +50,7 @@ class TestGetUser: mock_user.id = "user123" mock_session = MagicMock() mock_session_class.return_value.__enter__.return_value = mock_session - mock_session.query.return_value.where.return_value.first.return_value = mock_user + mock_session.get.return_value = mock_user # Act with app.app_context(): @@ -58,7 +58,7 @@ class TestGetUser: # Assert assert result == mock_user - mock_session.query.assert_called_once() + mock_session.get.assert_called_once() @patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.Session") @@ -72,7 +72,8 @@ class TestGetUser: mock_user.session_id = "anonymous_session" mock_session = MagicMock() mock_session_class.return_value.__enter__.return_value = mock_session - mock_session.query.return_value.where.return_value.first.return_value = mock_user + # non-anonymous path uses session.get(); anonymous uses session.scalar() + mock_session.get.return_value = mock_user # Act with app.app_context(): @@ -89,7 +90,7 @@ class TestGetUser: # Arrange mock_session = MagicMock() mock_session_class.return_value.__enter__.return_value = mock_session - mock_session.query.return_value.where.return_value.first.return_value = None + mock_session.get.return_value = None mock_new_user = MagicMock() mock_enduser_class.return_value = mock_new_user @@ -103,18 +104,20 @@ class TestGetUser: mock_session.commit.assert_called_once() mock_session.refresh.assert_called_once() + @patch("controllers.inner_api.plugin.wraps.select") @patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.Session") @patch("controllers.inner_api.plugin.wraps.db") def test_should_use_default_session_id_when_user_id_none( - self, mock_db, mock_session_class, mock_enduser_class, app: Flask + self, mock_db, mock_session_class, mock_enduser_class, mock_select, app: Flask ): """Test using default session ID when user_id is None""" # Arrange mock_user = MagicMock() mock_session = MagicMock() mock_session_class.return_value.__enter__.return_value = mock_session - mock_session.query.return_value.where.return_value.first.return_value = mock_user + # When user_id is None, is_anonymous=True, so session.scalar() is used + mock_session.scalar.return_value = mock_user # Act with app.app_context(): @@ -133,7 +136,7 @@ class TestGetUser: # Arrange mock_session = MagicMock() mock_session_class.return_value.__enter__.return_value = mock_session - mock_session.query.side_effect = Exception("Database error") + mock_session.get.side_effect = Exception("Database error") # Act & Assert with app.app_context(): @@ -161,9 +164,9 @@ class TestGetUserTenant: # Act with app.test_request_context(json={"tenant_id": "tenant123", "user_id": "user456"}): monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False) - with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query: + with patch("controllers.inner_api.plugin.wraps.db.session.get") as mock_get: with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user: - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_get.return_value = mock_tenant mock_get_user.return_value = mock_user result = protected_view() @@ -194,8 +197,8 @@ class TestGetUserTenant: # Act & Assert with app.test_request_context(json={"tenant_id": "nonexistent", "user_id": "user456"}): - with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query: - mock_query.return_value.where.return_value.first.return_value = None + with patch("controllers.inner_api.plugin.wraps.db.session.get") as mock_get: + mock_get.return_value = None with pytest.raises(ValueError, match="tenant not found"): protected_view() @@ -215,9 +218,9 @@ class TestGetUserTenant: # Act - use empty string for user_id to trigger default logic with app.test_request_context(json={"tenant_id": "tenant123", "user_id": ""}): monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False) - with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query: + with patch("controllers.inner_api.plugin.wraps.db.session.get") as mock_get: with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user: - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_get.return_value = mock_tenant mock_get_user.return_value = mock_user result = protected_view() diff --git a/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py b/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py index 883ccdea2c..efe1841f08 100644 --- a/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py +++ b/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py @@ -249,8 +249,8 @@ class TestEnterpriseInnerApiUserAuth: headers={"Authorization": f"Bearer {user_id}:{valid_signature}", "X-Inner-Api-Key": inner_api_key} ): with patch.object(dify_config, "INNER_API", True): - with patch("controllers.inner_api.wraps.db.session.query") as mock_query: - mock_query.return_value.where.return_value.first.return_value = mock_user + with patch("controllers.inner_api.wraps.db.session.get") as mock_get: + mock_get.return_value = mock_user result = protected_view() # Assert diff --git a/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py index 4fbf0f7125..56a8f94963 100644 --- a/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py @@ -91,7 +91,7 @@ class TestEnterpriseWorkspace: # Arrange mock_account = MagicMock() mock_account.email = "owner@example.com" - mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account + mock_db.session.scalar.return_value = mock_account now = datetime(2025, 1, 1, 12, 0, 0) mock_tenant = MagicMock() @@ -122,7 +122,7 @@ class TestEnterpriseWorkspace: def test_post_returns_404_when_owner_not_found(self, mock_db, api_instance, app: Flask): """Test that post() returns 404 when the owner account does not exist""" # Arrange - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act unwrapped_post = inspect.unwrap(api_instance.post) diff --git a/api/tests/unit_tests/controllers/web/test_human_input_form.py b/api/tests/unit_tests/controllers/web/test_human_input_form.py index 4fb735b033..a1dbc80b20 100644 --- a/api/tests/unit_tests/controllers/web/test_human_input_form.py +++ b/api/tests/unit_tests/controllers/web/test_human_input_form.py @@ -49,6 +49,17 @@ class _FakeSession: assert self._model_name is not None return self._mapping.get(self._model_name) + def get(self, model, ident): + return self._mapping.get(model.__name__) + + def scalar(self, stmt): + # Extract the model name from the select statement's column_descriptions + try: + name = stmt.column_descriptions[0]["entity"].__name__ + except (AttributeError, IndexError, KeyError): + return None + return self._mapping.get(name) + class _FakeDB: """Minimal db stub exposing engine and session.""" diff --git a/api/tests/unit_tests/controllers/web/test_site.py b/api/tests/unit_tests/controllers/web/test_site.py index 557bf93e9e..6e9d754c43 100644 --- a/api/tests/unit_tests/controllers/web/test_site.py +++ b/api/tests/unit_tests/controllers/web/test_site.py @@ -50,7 +50,7 @@ class TestAppSiteApi: app.config["RESTX_MASK_HEADER"] = "X-Fields" mock_features.return_value = SimpleNamespace(can_replace_logo=False) site_obj = _site() - mock_db.session.query.return_value.where.return_value.first.return_value = site_obj + mock_db.session.scalar.return_value = site_obj tenant = _tenant() app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True) end_user = SimpleNamespace(id="eu-1") @@ -66,9 +66,9 @@ class TestAppSiteApi: @patch("controllers.web.site.db") def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None: app.config["RESTX_MASK_HEADER"] = "X-Fields" - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None tenant = _tenant() - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True) end_user = SimpleNamespace(id="eu-1") with app.test_request_context("/site"): @@ -80,7 +80,7 @@ class TestAppSiteApi: app.config["RESTX_MASK_HEADER"] = "X-Fields" from models.account import TenantStatus - mock_db.session.query.return_value.where.return_value.first.return_value = _site() + mock_db.session.scalar.return_value = _site() tenant = SimpleNamespace( id="tenant-1", status=TenantStatus.ARCHIVE,