diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index b6d1df319e..6c54be84a8 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,7 +1,7 @@ import flask_restx from flask_restx import Resource, fields, marshal_with from flask_restx._http import HTTPStatus -from sqlalchemy import select +from sqlalchemy import delete, func, select from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -33,16 +33,10 @@ api_key_list_model = console_ns.model( def _get_resource(resource_id, tenant_id, resource_model): - if resource_model == App: - with Session(db.engine) as session: - resource = session.execute( - select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id) - ).scalar_one_or_none() - else: - with Session(db.engine) as session: - resource = session.execute( - select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id) - ).scalar_one_or_none() + with Session(db.engine) as session: + resource = session.execute( + select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id) + ).scalar_one_or_none() if resource is None: flask_restx.abort(HTTPStatus.NOT_FOUND, message=f"{resource_model.__name__} not found.") @@ -80,10 +74,13 @@ class BaseApiKeyListResource(Resource): resource_id = str(resource_id) _, current_tenant_id = current_account_with_tenant() _get_resource(resource_id, current_tenant_id, self.resource_model) - current_key_count = ( - db.session.query(ApiToken) - .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) - .count() + current_key_count: int = ( + db.session.scalar( + select(func.count(ApiToken.id)).where( + ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id + ) + ) + or 0 ) if current_key_count >= self.max_keys: @@ -119,14 +116,14 @@ class BaseApiKeyResource(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - key = ( - db.session.query(ApiToken) + key = db.session.scalar( + select(ApiToken) .where( getattr(ApiToken, self.resource_id_field) == resource_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id, ) - .first() + .limit(1) ) if key is None: @@ -137,7 +134,7 @@ class BaseApiKeyResource(Resource): assert key is not None # nosec - for type checker only ApiTokenCache.delete(key.token, key.type) - db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() + db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id)) db.session.commit() return {"result": "success"}, 204 diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index e099fe0f32..279e4ec502 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -2,6 +2,7 @@ from typing import Literal from flask import request from pydantic import BaseModel, Field, field_validator +from sqlalchemy import select from configs import dify_config from controllers.fastopenapi import console_router @@ -100,6 +101,6 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse: def get_setup_status() -> DifySetup | bool | None: if dify_config.EDITION == "SELF_HOSTED": - return db.session.query(DifySetup).first() + return db.session.scalar(select(DifySetup).limit(1)) return True diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 014f4c4132..6785ba0c34 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -7,6 +7,7 @@ from functools import wraps from typing import ParamSpec, TypeVar from flask import abort, request +from sqlalchemy import select from configs import dify_config from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError @@ -218,13 +219,9 @@ def setup_required(view: Callable[P, R]) -> Callable[P, R]: @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs) -> R: # check setup - if ( - dify_config.EDITION == "SELF_HOSTED" - and os.environ.get("INIT_PASSWORD") - and not db.session.query(DifySetup).first() - ): - raise NotInitValidateError() - elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first(): + if dify_config.EDITION == "SELF_HOSTED" and not db.session.scalar(select(DifySetup).limit(1)): + if os.environ.get("INIT_PASSWORD"): + raise NotInitValidateError() raise NotSetupError() return view(*args, **kwargs) diff --git a/api/tests/unit_tests/controllers/console/test_apikey.py b/api/tests/unit_tests/controllers/console/test_apikey.py index 018257f815..c18dd044a7 100644 --- a/api/tests/unit_tests/controllers/console/test_apikey.py +++ b/api/tests/unit_tests/controllers/console/test_apikey.py @@ -114,7 +114,7 @@ class TestBaseApiKeyResource: def test_delete_key_not_found(self, tenant_context_admin, db_mock): resource = DummyApiKeyResource() - db_mock.session.query.return_value.where.return_value.first.return_value = None + db_mock.session.scalar.return_value = None with patch("controllers.console.apikey._get_resource"): with pytest.raises(Exception) as exc_info: @@ -125,7 +125,7 @@ class TestBaseApiKeyResource: def test_delete_success(self, tenant_context_admin, db_mock): resource = DummyApiKeyResource() - db_mock.session.query.return_value.where.return_value.first.return_value = MagicMock() + db_mock.session.scalar.return_value = MagicMock() with ( patch("controllers.console.apikey._get_resource"), diff --git a/api/tests/unit_tests/controllers/console/test_wraps.py b/api/tests/unit_tests/controllers/console/test_wraps.py index 6777077de8..f6e096a97b 100644 --- a/api/tests/unit_tests/controllers/console/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/test_wraps.py @@ -328,7 +328,7 @@ class TestSystemSetup: def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db): """Test NotInitValidateError when INIT_PASSWORD is set but setup not complete""" # Arrange - mock_db.session.query.return_value.first.return_value = None # No setup + mock_db.session.scalar.return_value = None # No setup mock_environ_get.return_value = "some_password" @setup_required @@ -345,7 +345,7 @@ class TestSystemSetup: def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db): """Test NotSetupError when no INIT_PASSWORD and setup not complete""" # Arrange - mock_db.session.query.return_value.first.return_value = None # No setup + mock_db.session.scalar.return_value = None # No setup mock_environ_get.return_value = None # No INIT_PASSWORD @setup_required