mirror of
https://github.com/langgenius/dify.git
synced 2026-03-25 00:07:56 +08:00
refactor: migrate db.session.query to select in inner_api and web controllers (#33774)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -5,6 +5,7 @@ from typing import ParamSpec, TypeVar
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
@ -36,23 +37,16 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
user_model = None
|
||||
|
||||
if is_anonymous:
|
||||
user_model = (
|
||||
session.query(EndUser)
|
||||
user_model = session.scalar(
|
||||
select(EndUser)
|
||||
.where(
|
||||
EndUser.session_id == user_id,
|
||||
EndUser.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
else:
|
||||
user_model = (
|
||||
session.query(EndUser)
|
||||
.where(
|
||||
EndUser.id == user_id,
|
||||
EndUser.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
user_model = session.get(EndUser, user_id)
|
||||
|
||||
if not user_model:
|
||||
user_model = EndUser(
|
||||
@ -85,16 +79,7 @@ def get_user_tenant(view_func: Callable[P, R]):
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
|
||||
try:
|
||||
tenant_model = (
|
||||
db.session.query(Tenant)
|
||||
.where(
|
||||
Tenant.id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError("tenant not found")
|
||||
tenant_model = db.session.get(Tenant, tenant_id)
|
||||
|
||||
if not tenant_model:
|
||||
raise ValueError("tenant not found")
|
||||
|
||||
@ -2,6 +2,7 @@ import json
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.wraps import setup_required
|
||||
@ -42,7 +43,7 @@ class EnterpriseWorkspace(Resource):
|
||||
def post(self):
|
||||
args = WorkspaceCreatePayload.model_validate(inner_api_ns.payload or {})
|
||||
|
||||
account = db.session.query(Account).filter_by(email=args.owner_email).first()
|
||||
account = db.session.scalar(select(Account).where(Account.email == args.owner_email).limit(1))
|
||||
if account is None:
|
||||
return {"message": "owner account not found."}, 404
|
||||
|
||||
|
||||
@ -75,7 +75,7 @@ def enterprise_inner_api_user_auth(view: Callable[P, R]):
|
||||
if signature_base64 != token:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
kwargs["user"] = db.session.query(EndUser).where(EndUser.id == user_id).first()
|
||||
kwargs["user"] = db.session.get(EndUser, user_id)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ from datetime import datetime
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
@ -147,11 +148,11 @@ class HumanInputFormApi(Resource):
|
||||
|
||||
def _get_app_site_from_form(form: Form) -> tuple[App, Site]:
|
||||
"""Resolve App/Site for the form's app and validate tenant status."""
|
||||
app_model = db.session.query(App).where(App.id == form.app_id).first()
|
||||
app_model = db.session.get(App, form.app_id)
|
||||
if app_model is None or app_model.tenant_id != form.tenant_id:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
if site is None:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from typing import cast
|
||||
|
||||
from flask_restx import fields, marshal, marshal_with
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
@ -72,7 +73,7 @@ class AppSiteApi(WebApiResource):
|
||||
def get(self, app_model, end_user):
|
||||
"""Retrieve app site info."""
|
||||
# get site
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
|
||||
if not site:
|
||||
raise Forbidden()
|
||||
|
||||
@ -50,7 +50,7 @@ class TestGetUser:
|
||||
mock_user.id = "user123"
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_user
|
||||
mock_session.get.return_value = mock_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
@ -58,7 +58,7 @@ class TestGetUser:
|
||||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
mock_session.query.assert_called_once()
|
||||
mock_session.get.assert_called_once()
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@ -72,7 +72,8 @@ class TestGetUser:
|
||||
mock_user.session_id = "anonymous_session"
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_user
|
||||
# non-anonymous path uses session.get(); anonymous uses session.scalar()
|
||||
mock_session.get.return_value = mock_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
@ -89,7 +90,7 @@ class TestGetUser:
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_session.get.return_value = None
|
||||
mock_new_user = MagicMock()
|
||||
mock_enduser_class.return_value = mock_new_user
|
||||
|
||||
@ -103,18 +104,20 @@ class TestGetUser:
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.refresh.assert_called_once()
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.select")
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_use_default_session_id_when_user_id_none(
|
||||
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
|
||||
self, mock_db, mock_session_class, mock_enduser_class, mock_select, app: Flask
|
||||
):
|
||||
"""Test using default session ID when user_id is None"""
|
||||
# Arrange
|
||||
mock_user = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_user
|
||||
# When user_id is None, is_anonymous=True, so session.scalar() is used
|
||||
mock_session.scalar.return_value = mock_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
@ -133,7 +136,7 @@ class TestGetUser:
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.side_effect = Exception("Database error")
|
||||
mock_session.get.side_effect = Exception("Database error")
|
||||
|
||||
# Act & Assert
|
||||
with app.app_context():
|
||||
@ -161,9 +164,9 @@ class TestGetUserTenant:
|
||||
# Act
|
||||
with app.test_request_context(json={"tenant_id": "tenant123", "user_id": "user456"}):
|
||||
monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False)
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.get") as mock_get:
|
||||
with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user:
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_get.return_value = mock_tenant
|
||||
mock_get_user.return_value = mock_user
|
||||
result = protected_view()
|
||||
|
||||
@ -194,8 +197,8 @@ class TestGetUserTenant:
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(json={"tenant_id": "nonexistent", "user_id": "user456"}):
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.first.return_value = None
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.get") as mock_get:
|
||||
mock_get.return_value = None
|
||||
with pytest.raises(ValueError, match="tenant not found"):
|
||||
protected_view()
|
||||
|
||||
@ -215,9 +218,9 @@ class TestGetUserTenant:
|
||||
# Act - use empty string for user_id to trigger default logic
|
||||
with app.test_request_context(json={"tenant_id": "tenant123", "user_id": ""}):
|
||||
monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False)
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.get") as mock_get:
|
||||
with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user:
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_get.return_value = mock_tenant
|
||||
mock_get_user.return_value = mock_user
|
||||
result = protected_view()
|
||||
|
||||
|
||||
@ -249,8 +249,8 @@ class TestEnterpriseInnerApiUserAuth:
|
||||
headers={"Authorization": f"Bearer {user_id}:{valid_signature}", "X-Inner-Api-Key": inner_api_key}
|
||||
):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch("controllers.inner_api.wraps.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_user
|
||||
with patch("controllers.inner_api.wraps.db.session.get") as mock_get:
|
||||
mock_get.return_value = mock_user
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
|
||||
@ -91,7 +91,7 @@ class TestEnterpriseWorkspace:
|
||||
# Arrange
|
||||
mock_account = MagicMock()
|
||||
mock_account.email = "owner@example.com"
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
|
||||
mock_db.session.scalar.return_value = mock_account
|
||||
|
||||
now = datetime(2025, 1, 1, 12, 0, 0)
|
||||
mock_tenant = MagicMock()
|
||||
@ -122,7 +122,7 @@ class TestEnterpriseWorkspace:
|
||||
def test_post_returns_404_when_owner_not_found(self, mock_db, api_instance, app: Flask):
|
||||
"""Test that post() returns 404 when the owner account does not exist"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
unwrapped_post = inspect.unwrap(api_instance.post)
|
||||
|
||||
@ -49,6 +49,17 @@ class _FakeSession:
|
||||
assert self._model_name is not None
|
||||
return self._mapping.get(self._model_name)
|
||||
|
||||
def get(self, model, ident):
|
||||
return self._mapping.get(model.__name__)
|
||||
|
||||
def scalar(self, stmt):
|
||||
# Extract the model name from the select statement's column_descriptions
|
||||
try:
|
||||
name = stmt.column_descriptions[0]["entity"].__name__
|
||||
except (AttributeError, IndexError, KeyError):
|
||||
return None
|
||||
return self._mapping.get(name)
|
||||
|
||||
|
||||
class _FakeDB:
|
||||
"""Minimal db stub exposing engine and session."""
|
||||
|
||||
@ -50,7 +50,7 @@ class TestAppSiteApi:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
mock_features.return_value = SimpleNamespace(can_replace_logo=False)
|
||||
site_obj = _site()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = site_obj
|
||||
mock_db.session.scalar.return_value = site_obj
|
||||
tenant = _tenant()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
|
||||
end_user = SimpleNamespace(id="eu-1")
|
||||
@ -66,9 +66,9 @@ class TestAppSiteApi:
|
||||
@patch("controllers.web.site.db")
|
||||
def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
tenant = _tenant()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
|
||||
end_user = SimpleNamespace(id="eu-1")
|
||||
|
||||
with app.test_request_context("/site"):
|
||||
@ -80,7 +80,7 @@ class TestAppSiteApi:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
from models.account import TenantStatus
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = _site()
|
||||
mock_db.session.scalar.return_value = _site()
|
||||
tenant = SimpleNamespace(
|
||||
id="tenant-1",
|
||||
status=TenantStatus.ARCHIVE,
|
||||
|
||||
Reference in New Issue
Block a user