From 98de360447d93eeec73290e14feecf8952e92d94 Mon Sep 17 00:00:00 2001 From: "yunlu.wen" Date: Sat, 23 May 2026 14:11:52 +0800 Subject: [PATCH] refactor: move db query from api leyer to service layer --- api/controllers/openapi/account.py | 4 +- api/controllers/openapi/apps.py | 34 ++--- .../openapi/apps_permitted_external.py | 17 +-- api/controllers/openapi/auth/steps.py | 8 +- api/controllers/openapi/auth/strategies.py | 24 +--- api/controllers/openapi/oauth_device.py | 14 +- api/controllers/openapi/oauth_device_sso.py | 31 +---- api/services/account_service.py | 78 +++++++++++ api/services/app_service.py | 48 +++++++ api/services/oauth_device_flow.py | 7 +- .../openapi/auth/test_step_authz.py | 7 +- .../controllers/openapi/test_device_sso.py | 26 ---- .../services/test_account_service.py | 91 +++++++++++- .../unit_tests/services/test_app_service.py | 129 ++++++++++++++++++ 14 files changed, 380 insertions(+), 138 deletions(-) create mode 100644 api/tests/unit_tests/services/test_app_service.py diff --git a/api/controllers/openapi/account.py b/api/controllers/openapi/account.py index b6ed0dae51..602d7e7ab4 100644 --- a/api/controllers/openapi/account.py +++ b/api/controllers/openapi/account.py @@ -62,9 +62,7 @@ class AccountApi(Resource): ).model_dump(mode="json") account = AccountService.get_account_by_id(db.session, str(ctx.account_id)) if ctx.account_id else None - memberships = ( - TenantService.get_account_memberships(db.session, str(ctx.account_id)) if ctx.account_id else [] - ) + memberships = TenantService.get_account_memberships(db.session, str(ctx.account_id)) if ctx.account_id else [] default_ws_id = _pick_default_workspace(memberships) return AccountResponse( diff --git a/api/controllers/openapi/apps.py b/api/controllers/openapi/apps.py index d42961851b..8a3fc81809 100644 --- a/api/controllers/openapi/apps.py +++ b/api/controllers/openapi/apps.py @@ -10,7 +10,6 @@ from __future__ import annotations import uuid as _uuid from typing import Any, cast -import sqlalchemy as sa from flask import request from flask_restx import Resource from pydantic import ValidationError @@ -43,9 +42,9 @@ from libs.oauth_bearer import ( require_workspace_member, validate_bearer, ) -from models import App, Tenant +from models import App +from services.account_service import TenantService from services.app_service import AppListParams, AppService -from services.openapi.visibility import apply_openapi_gate, is_openapi_visible from services.tag_service import TagService _APPS_READ_DECORATORS = [ @@ -82,23 +81,14 @@ class AppReadResource(Resource): is_uuid = False if is_uuid: - app = db.session.get(App, str(parsed_uuid)) # normalised dashed form - if not app or app.status != "normal" or not is_openapi_visible(app): + # ``str(parsed_uuid)`` normalises to the canonical dashed form. + app = AppService.get_visible_app_by_id(db.session, str(parsed_uuid)) + if app is None: raise NotFound("app not found") else: if not workspace_id: raise UnprocessableEntity("workspace_id is required for name-based lookup") - matches = list( - db.session.execute( - apply_openapi_gate( - sa.select(App).where( - App.name == app_id, - App.tenant_id == workspace_id, - App.status == "normal", - ) - ) - ).scalars() - ) + matches = AppService.find_visible_apps_by_name(db.session, name=app_id, tenant_id=workspace_id) if len(matches) == 0: raise NotFound("app not found") if len(matches) > 1: @@ -210,12 +200,10 @@ class AppListApi(Resource): tenant_name: str | None = None if parsed_uuid is not None: - app: App | None = db.session.get(App, str(parsed_uuid)) - if not app or app.status != "normal" or str(app.tenant_id) != workspace_id or not is_openapi_visible(app): + app: App | None = AppService.get_visible_app_by_id(db.session, str(parsed_uuid)) + if app is None or str(app.tenant_id) != workspace_id: return empty - tenant_name = db.session.execute( - sa.select(Tenant.name).where(Tenant.id == workspace_id) - ).scalar_one_or_none() + tenant_name = TenantService.get_tenant_name(db.session, workspace_id) item = AppListRow( id=str(app.id), name=app.name, @@ -255,9 +243,7 @@ class AppListApi(Resource): tenant_name = None if pagination.items: - tenant_name = db.session.execute( - sa.select(Tenant.name).where(Tenant.id == workspace_id) - ).scalar_one_or_none() + tenant_name = TenantService.get_tenant_name(db.session, workspace_id) items = [ AppListRow( diff --git a/api/controllers/openapi/apps_permitted_external.py b/api/controllers/openapi/apps_permitted_external.py index 152ef8aee7..9359dca228 100644 --- a/api/controllers/openapi/apps_permitted_external.py +++ b/api/controllers/openapi/apps_permitted_external.py @@ -7,7 +7,6 @@ EE blueprint chain so this module is unreachable there. from __future__ import annotations -import sqlalchemy as sa from flask import request from flask_restx import Resource from pydantic import ValidationError @@ -29,10 +28,11 @@ from libs.oauth_bearer import ( require_scope, validate_bearer, ) -from models import App, Tenant +from models import App +from services.account_service import TenantService +from services.app_service import AppService from services.enterprise.app_permitted_service import list_permitted_apps from services.openapi.license_gate import license_required -from services.openapi.visibility import apply_openapi_gate @openapi_ns.route("/permitted-external-apps") @@ -68,15 +68,10 @@ class PermittedExternalAppsListApi(Resource): return env.model_dump(mode="json"), 200 apps_by_id: dict[str, App] = { - str(a.id): a - for a in db.session.execute(apply_openapi_gate(sa.select(App).where(App.id.in_(page_result.app_ids)))) - .scalars() - .all() - } - tenant_ids = list({a.tenant_id for a in apps_by_id.values()}) - tenants_by_id = { - str(t.id): t for t in db.session.execute(sa.select(Tenant).where(Tenant.id.in_(tenant_ids))).scalars().all() + str(a.id): a for a in AppService.find_visible_apps_by_ids(db.session, page_result.app_ids) } + tenant_ids = list({str(a.tenant_id) for a in apps_by_id.values()}) + tenants_by_id = {str(t.id): t for t in TenantService.get_tenants_by_ids(db.session, tenant_ids)} items: list[AppListRow] = [] for app_id in page_result.app_ids: diff --git a/api/controllers/openapi/auth/steps.py b/api/controllers/openapi/auth/steps.py index 377ffe5300..40a168b489 100644 --- a/api/controllers/openapi/auth/steps.py +++ b/api/controllers/openapi/auth/steps.py @@ -30,7 +30,9 @@ from libs.oauth_bearer import ( get_authenticator, set_auth_ctx, ) -from models import App, Tenant, TenantStatus +from models import TenantStatus +from services.account_service import TenantService +from services.app_service import AppService class BearerCheck: @@ -97,12 +99,12 @@ class AppResolver: app_id = ctx.path_params.get("app_id") if not app_id: raise BadRequest("app_id is required in path") - app = db.session.get(App, app_id) + app = AppService.get_app_by_id(db.session, app_id) if not app or app.status != "normal": raise NotFound("app not found") if not app.enable_api: raise Forbidden("service_api_disabled") - tenant = db.session.get(Tenant, app.tenant_id) + tenant = TenantService.get_tenant_by_id(db.session, str(app.tenant_id)) if tenant is None or tenant.status == TenantStatus.ARCHIVE: raise Forbidden("workspace unavailable") ctx.app, ctx.tenant = app, tenant diff --git a/api/controllers/openapi/auth/strategies.py b/api/controllers/openapi/auth/strategies.py index eb669c62f8..aaaaadd948 100644 --- a/api/controllers/openapi/auth/strategies.py +++ b/api/controllers/openapi/auth/strategies.py @@ -7,18 +7,16 @@ composition stays a flat list. from __future__ import annotations -import uuid from typing import Protocol from flask import current_app from flask_login import user_logged_in -from sqlalchemy import select from controllers.openapi.auth.context import Context from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from libs.oauth_bearer import SubjectType -from models import Account, TenantAccountJoin +from services.account_service import AccountService, TenantService from services.end_user_service import EndUserService from services.enterprise.enterprise_service import ( EnterpriseService, @@ -106,9 +104,7 @@ class AclStrategy: return str(ctx.account_id) if ctx.account_id is not None else None if ctx.subject_email is None: return None - account = db.session.execute( - select(Account).where(Account.email == ctx.subject_email), - ).scalar_one_or_none() + account = AccountService.get_account_by_email(db.session, ctx.subject_email) return str(account.id) if account is not None else None @@ -125,19 +121,7 @@ class MembershipStrategy: return False if ctx.tenant is None: return False - return _has_tenant_membership(ctx.account_id, ctx.tenant.id) - - -def _has_tenant_membership(account_id: uuid.UUID | str | None, tenant_id: str) -> bool: - if not account_id: - return False - row = db.session.execute( - select(TenantAccountJoin.id).where( - TenantAccountJoin.tenant_id == tenant_id, - TenantAccountJoin.account_id == account_id, - ) - ).scalar_one_or_none() - return row is not None + return TenantService.account_belongs_to_tenant(db.session, ctx.account_id, ctx.tenant.id) def _login_as(user) -> None: @@ -159,7 +143,7 @@ class AccountMounter: def mount(self, ctx: Context) -> None: if ctx.account_id is None: raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run") - account = db.session.get(Account, ctx.account_id) + account = AccountService.get_account_by_id(db.session, str(ctx.account_id)) if account is None: raise RuntimeError("AccountMounter: account row missing for resolved bearer") account.current_tenant = ctx.must_tenant diff --git a/api/controllers/openapi/oauth_device.py b/api/controllers/openapi/oauth_device.py index 87ca64f0e7..bbee345767 100644 --- a/api/controllers/openapi/oauth_device.py +++ b/api/controllers/openapi/oauth_device.py @@ -50,6 +50,7 @@ from libs.rate_limit import ( LIMIT_LOOKUP_PUBLIC, rate_limit, ) +from services.account_service import TenantService from services.oauth_device_flow import ( ACCOUNT_ISSUER_SENTINEL, DEFAULT_POLL_INTERVAL_SECONDS, @@ -333,18 +334,7 @@ def _audit_cross_ip_if_needed(state) -> None: def _build_account_poll_payload(account, tenant, mint) -> PollPayload: - """Account branch of the shared `PollPayload` contract. SSO-only fields - (`subject_email`, `subject_issuer`) are intentionally omitted; see the - `PollPayload` docstring in `services.oauth_device_flow`. - """ - from models import Tenant, TenantAccountJoin - - rows = ( - db.session.query(Tenant, TenantAccountJoin) - .join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id) - .filter(TenantAccountJoin.account_id == account.id) - .all() - ) + rows = TenantService.get_workspaces_for_account(db.session, str(account.id)) workspaces = [WorkspacePayload(id=str(t.id), name=t.name, role=getattr(m, "role", "")) for t, m in rows] # Prefer active session tenant → DB-flagged current join → first membership. default_ws_id = None diff --git a/api/controllers/openapi/oauth_device_sso.py b/api/controllers/openapi/oauth_device_sso.py index 49866e1156..08ecce0a38 100644 --- a/api/controllers/openapi/oauth_device_sso.py +++ b/api/controllers/openapi/oauth_device_sso.py @@ -17,7 +17,6 @@ import secrets from dataclasses import dataclass from flask import jsonify, make_response, redirect, request -from sqlalchemy import func, select from werkzeug.exceptions import ( BadGateway, BadRequest, @@ -49,8 +48,7 @@ from libs.rate_limit import ( enforce, rate_limit, ) -from models import Account -from models.account import AccountStatus +from services.account_service import AccountService from services.enterprise.enterprise_service import EnterpriseService from services.oauth_device_flow import ( DeviceFlowRedis, @@ -149,7 +147,7 @@ def sso_complete(): if state.status is not DeviceFlowStatus.PENDING: raise Conflict("user_code_not_pending") - if _email_belongs_to_dify_account(claims["email"]): + if AccountService.has_active_account_with_email(db.session, claims["email"]): _emit_external_rejection_audit( state, _RejectedClaims(subject_email=claims["email"], subject_issuer=claims["issuer"]), @@ -229,7 +227,7 @@ def approve_external(): if state.status is not DeviceFlowStatus.PENDING: raise Conflict("user_code_not_pending") - if _email_belongs_to_dify_account(claims.subject_email): + if AccountService.has_active_account_with_email(db.session, claims.subject_email): _emit_external_rejection_audit(state, claims, reason="email_belongs_to_dify_account") raise Forbidden("email_belongs_to_dify_account") @@ -308,29 +306,6 @@ class _RejectedClaims: subject_issuer: str -def _email_belongs_to_dify_account(email: str) -> bool: - """External SSO subjects whose email matches an active Dify Account must - authenticate via the internal Dify login path (which mints dfoa_), not via - the external SSO device flow. Returning True here blocks dfoe_ minting. - - Pending/uninitialized/banned/closed accounts do not block: pending and - uninitialized users may complete invitation via SSO; banned and closed - accounts are handled by separate enforcement paths. - """ - if not email: - return False - normalized = email.strip().lower() - if not normalized: - return False - row = db.session.execute( - select(Account.id).where( - func.lower(Account.email) == normalized, - Account.status == AccountStatus.ACTIVE, - ), - ).scalar_one_or_none() - return row is not None - - def _emit_external_rejection_audit(state, claims, *, reason: str) -> None: logger.warning( "audit: oauth.device_flow_rejected subject_type=%s subject_email=%s subject_issuer=%s reason=%s", diff --git a/api/services/account_service.py b/api/services/account_service.py index 4d1ebfbe6e..344b3619f2 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -164,6 +164,29 @@ class AccountService: redis_client.delete(AccountService._get_refresh_token_key(refresh_token)) redis_client.delete(AccountService._get_account_refresh_token_key(account_id)) + @staticmethod + def get_account_by_email(session: Session | scoped_session, email: str) -> Account | None: + """Plain ``Account`` getter keyed by email. Case-sensitive — use + :meth:`has_active_account_with_email` for the case-insensitive + existence check that backs the SSO collision rule. + """ + return session.execute(select(Account).where(Account.email == email)).scalar_one_or_none() + + @staticmethod + def has_active_account_with_email(session: Session | scoped_session, email: str) -> bool: + if not email: + return False + normalized = email.strip().lower() + if not normalized: + return False + row = session.execute( + select(Account.id).where( + func.lower(Account.email) == normalized, + Account.status == AccountStatus.ACTIVE, + ) + ).scalar_one_or_none() + return row is not None + @staticmethod def get_account_by_id(session: Session | scoped_session, account_id: str) -> Account | None: """Plain ``Account`` getter — no banned check, no tenant rotation, @@ -1241,6 +1264,61 @@ class TenantService: ).all() ) + @staticmethod + def account_belongs_to_tenant( + session: Session | scoped_session, + account_id: uuid.UUID | str | None, + tenant_id: str, + ) -> bool: + """Existence check for ``TenantAccountJoin(account_id, tenant_id)``. + Backs the CE-deployment membership fallback in + ``controllers.openapi.auth.strategies.MembershipStrategy``. + + ``None``/empty ``account_id`` short-circuits to ``False`` so SSO + bearers (no account) and missing identity collapse cleanly. + """ + if not account_id: + return False + row = session.execute( + select(TenantAccountJoin.id).where( + TenantAccountJoin.tenant_id == tenant_id, + TenantAccountJoin.account_id == account_id, + ) + ).scalar_one_or_none() + return row is not None + + @staticmethod + def get_tenant_by_id(session: Session | scoped_session, tenant_id: str) -> Tenant | None: + """Plain ``session.get(Tenant, tenant_id)`` — no status filter. + Callers map ``status == ARCHIVE`` to their own error code (the + openapi auth pipeline raises 403 ``workspace unavailable``). + """ + return session.get(Tenant, tenant_id) + + @staticmethod + def get_tenants_by_ids( + session: Session | scoped_session, + tenant_ids: list[str], + ) -> list[Tenant]: + """Bulk ``Tenant`` fetch by primary-key list. Order is unspecified + — callers index by ``tenant.id`` (e.g. for cross-tenant denorm + in ``/openapi/v1/permitted-external-apps``). + + Empty input short-circuits to ``[]`` to avoid emitting an + ``IN ()`` SQL fragment. + """ + if not tenant_ids: + return [] + return list(session.execute(select(Tenant).where(Tenant.id.in_(tenant_ids))).scalars().all()) + + @staticmethod + def get_tenant_name(session: Session | scoped_session, tenant_id: str) -> str | None: + """Single-column tenant name read. Used by openapi list endpoints + to denormalize ``workspace_name`` onto each row without dragging + the full ``Tenant`` ORM entity through. + """ + return session.execute(select(Tenant.name).where(Tenant.id == tenant_id)).scalar_one_or_none() + @staticmethod def find_workspace_for_account( session: Session | scoped_session, diff --git a/api/services/app_service.py b/api/services/app_service.py index b78f753364..bc867e8dc4 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -1,11 +1,13 @@ import json import logging +from collections.abc import Sequence from typing import Any, Literal, TypedDict, cast import sqlalchemy as sa from flask_sqlalchemy.pagination import Pagination from pydantic import BaseModel, Field from sqlalchemy import select +from sqlalchemy.orm import Session, scoped_session from configs import dify_config from constants.model_template import default_app_templates @@ -26,6 +28,7 @@ from models.tools import ApiToolProvider from services.billing_service import BillingService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService +from services.openapi.visibility import apply_openapi_gate, is_openapi_visible from services.tag_service import TagService from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task @@ -56,6 +59,51 @@ class CreateAppParams(BaseModel): class AppService: + @staticmethod + def get_app_by_id( + session: Session | scoped_session, + app_id: str, + ) -> App | None: + return session.get(App, app_id) + + @staticmethod + def get_visible_app_by_id( + session: Session | scoped_session, + app_id: str, + ) -> App | None: + app = session.get(App, app_id) + if not app or app.status != "normal" or not is_openapi_visible(app): + return None + return app + + @staticmethod + def find_visible_apps_by_ids( + session: Session | scoped_session, + app_ids: Sequence[str], + ) -> list[App]: + if not app_ids: + return [] + return list(session.execute(apply_openapi_gate(select(App).where(App.id.in_(list(app_ids))))).scalars().all()) + + @staticmethod + def find_visible_apps_by_name( + session: Session | scoped_session, + *, + name: str, + tenant_id: str, + ) -> list[App]: + return list( + session.execute( + apply_openapi_gate( + select(App).where( + App.name == name, + App.tenant_id == tenant_id, + App.status == "normal", + ) + ) + ).scalars() + ) + def get_paginate_apps(self, user_id: str, tenant_id: str, params: AppListParams) -> Pagination | None: """ Get app list with pagination diff --git a/api/services/oauth_device_flow.py b/api/services/oauth_device_flow.py index b20dc9e3f1..07b48cf4ee 100644 --- a/api/services/oauth_device_flow.py +++ b/api/services/oauth_device_flow.py @@ -12,7 +12,7 @@ from datetime import UTC, datetime, timedelta from enum import StrEnum from typing import Any, NotRequired, TypedDict -from sqlalchemy import Row, and_, func, select, update +from sqlalchemy import and_, func, select, update from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.orm import Session, scoped_session @@ -492,6 +492,7 @@ def oauth_ttl_days(tenant_id: str | None = None) -> int: return MAX_TTL_DAYS return value + def subject_match_clauses(ctx: AuthContext) -> tuple[Any, ...]: if ctx.subject_type == SubjectType.ACCOUNT: return (OAuthAccessToken.account_id == str(ctx.account_id),) @@ -509,9 +510,7 @@ def list_active_sessions( ) -> list[OAuthAccessToken]: return list( session.execute( - select( - OAuthAccessToken - ) + select(OAuthAccessToken) .where( and_( *subject_match_clauses(ctx), diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_step_authz.py b/api/tests/unit_tests/controllers/openapi/auth/test_step_authz.py index fa1d8d27af..6a5933da3b 100644 --- a/api/tests/unit_tests/controllers/openapi/auth/test_step_authz.py +++ b/api/tests/unit_tests/controllers/openapi/auth/test_step_authz.py @@ -52,11 +52,12 @@ def test_acl_strategy_subject_mode_matrix(ent, access_mode, subject_type, expect ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_not_called() -@patch("controllers.openapi.auth.strategies._has_tenant_membership") -def test_membership_strategy_uses_join_lookup(member): +@patch("controllers.openapi.auth.strategies.TenantService.account_belongs_to_tenant") +@patch("controllers.openapi.auth.strategies.db") +def test_membership_strategy_uses_join_lookup(db_mock, member): member.return_value = True assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True - member.assert_called_once_with("acc1", "t1") + member.assert_called_once_with(db_mock.session, "acc1", "t1") def test_membership_strategy_rejects_external_sso(): diff --git a/api/tests/unit_tests/controllers/openapi/test_device_sso.py b/api/tests/unit_tests/controllers/openapi/test_device_sso.py index 95e4466a4f..0125c583f0 100644 --- a/api/tests/unit_tests/controllers/openapi/test_device_sso.py +++ b/api/tests/unit_tests/controllers/openapi/test_device_sso.py @@ -1,7 +1,6 @@ """SSO-branch device-flow endpoints under /openapi/v1/oauth/device/.""" import builtins -from unittest.mock import MagicMock, patch import pytest from flask import Flask @@ -9,7 +8,6 @@ from flask.views import MethodView from controllers.openapi import bp as openapi_bp from controllers.openapi.oauth_device_sso import ( - _email_belongs_to_dify_account, approval_context, approve_external, sso_complete, @@ -79,27 +77,3 @@ def test_sso_complete_idp_callback_url_uses_canonical_path(): from controllers.openapi import oauth_device_sso assert oauth_device_sso._SSO_COMPLETE_PATH == "/openapi/v1/oauth/device/sso-complete" - - -@pytest.mark.parametrize( - ("email", "row", "expected"), - [ - ("alice@example.com", "acc1", True), - ("alice@example.com", None, False), - ("Alice@Example.COM", "acc1", True), # case-insensitive lookup - (" alice@example.com ", "acc1", True), # surrounding whitespace stripped - ("", "acc1", False), - (" ", "acc1", False), - ("", None, False), - ], -) -@patch("controllers.openapi.oauth_device_sso.db") -def test_email_belongs_to_dify_account(db_mock, email, row, expected): - exec_result = MagicMock() - exec_result.scalar_one_or_none.return_value = row - db_mock.session.execute.return_value = exec_result - assert _email_belongs_to_dify_account(email) is expected - if email.strip(): - db_mock.session.execute.assert_called_once() - else: - db_mock.session.execute.assert_not_called() diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index 76de2936e0..5e89d9fb42 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -2021,6 +2021,44 @@ class TestSessionInjectedGetters: assert AccountService.get_account_by_id(mock_session, "missing") is None + def test_get_account_by_email_returns_scalar_or_none(self): + """Plain getter — case-sensitive equality (callers needing the + case-insensitive existence check use + :meth:`has_active_account_with_email`). + """ + mock_session = MagicMock() + sentinel = MagicMock(spec=Account) + mock_session.execute.return_value.scalar_one_or_none.return_value = sentinel + + assert AccountService.get_account_by_email(mock_session, "alice@example.com") is sentinel + + mock_session.execute.return_value.scalar_one_or_none.return_value = None + assert AccountService.get_account_by_email(mock_session, "ghost@example.com") is None + + def test_account_belongs_to_tenant_short_circuits_on_falsy_account_id(self): + """SSO bearers with no ``account_id`` (and any other falsy id) + must collapse to ``False`` without a DB round-trip — that's the + contract :class:`MembershipStrategy` relies on. + """ + mock_session = MagicMock() + + assert TenantService.account_belongs_to_tenant(mock_session, None, "tenant-1") is False + assert TenantService.account_belongs_to_tenant(mock_session, "", "tenant-1") is False + mock_session.execute.assert_not_called() + + def test_account_belongs_to_tenant_true_when_join_row_exists(self): + mock_session = MagicMock() + mock_session.execute.return_value.scalar_one_or_none.return_value = "join-id" + + assert TenantService.account_belongs_to_tenant(mock_session, "user-1", "tenant-1") is True + mock_session.execute.assert_called_once() + + def test_account_belongs_to_tenant_false_when_no_join(self): + mock_session = MagicMock() + mock_session.execute.return_value.scalar_one_or_none.return_value = None + + assert TenantService.account_belongs_to_tenant(mock_session, "user-1", "tenant-1") is False + def test_get_account_memberships_returns_join_tenant_pairs(self): """Returns whatever ``session.query(...).join(...).filter(...).all()`` produces — ordering unspecified, callers pick the default @@ -2049,6 +2087,54 @@ class TestSessionInjectedGetters: assert out == rows assert mock_session.execute.called + def test_get_tenant_by_id_is_plain_session_get(self): + """``get_tenant_by_id`` must NOT apply a status filter — the + openapi auth pipeline needs to map ``status == ARCHIVE`` to a + 403, distinct from a 404 for "missing". + """ + from models import Tenant + + mock_session = MagicMock() + sentinel = MagicMock(spec=Tenant) + mock_session.get.return_value = sentinel + + assert TenantService.get_tenant_by_id(mock_session, "tenant-1") is sentinel + mock_session.get.assert_called_once_with(Tenant, "tenant-1") + + def test_get_tenant_by_id_returns_none_when_missing(self): + mock_session = MagicMock() + mock_session.get.return_value = None + + assert TenantService.get_tenant_by_id(mock_session, "missing") is None + + def test_get_tenants_by_ids_short_circuits_on_empty_input(self): + """Empty id list must not emit ``WHERE id IN ()``.""" + mock_session = MagicMock() + + assert TenantService.get_tenants_by_ids(mock_session, []) == [] + mock_session.execute.assert_not_called() + + def test_get_tenants_by_ids_returns_scalars(self): + mock_session = MagicMock() + tenants = [MagicMock(), MagicMock()] + mock_session.execute.return_value.scalars.return_value.all.return_value = tenants + + assert TenantService.get_tenants_by_ids(mock_session, ["t1", "t2"]) == tenants + mock_session.execute.assert_called_once() + + def test_get_tenant_name_returns_scalar_or_none(self): + """Single-column lookup: ``session.execute(...).scalar_one_or_none()`` + — used by openapi list endpoints to denormalise + ``workspace_name`` onto each row. + """ + mock_session = MagicMock() + mock_session.execute.return_value.scalar_one_or_none.return_value = "Acme Inc." + + assert TenantService.get_tenant_name(mock_session, "tenant-1") == "Acme Inc." + + mock_session.execute.return_value.scalar_one_or_none.return_value = None + assert TenantService.get_tenant_name(mock_session, "missing") is None + def test_find_workspace_for_account_returns_first_row_or_none(self): """Per-id read returns ``session.execute(...).first()`` directly; callers map ``None`` → 404 to avoid leaking workspace IDs across @@ -2058,10 +2144,7 @@ class TestSessionInjectedGetters: sentinel_row = (MagicMock(), MagicMock()) mock_session.execute.return_value.first.return_value = sentinel_row - assert ( - TenantService.find_workspace_for_account(mock_session, "user-123", "ws-1") - is sentinel_row - ) + assert TenantService.find_workspace_for_account(mock_session, "user-123", "ws-1") is sentinel_row mock_session.execute.return_value.first.return_value = None assert TenantService.find_workspace_for_account(mock_session, "user-123", "ws-1") is None diff --git a/api/tests/unit_tests/services/test_app_service.py b/api/tests/unit_tests/services/test_app_service.py new file mode 100644 index 0000000000..610b32ac3c --- /dev/null +++ b/api/tests/unit_tests/services/test_app_service.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from models.model import App +from services.app_service import AppService + + +class TestOpenapiVisibilityHelpers: + """Coverage for the session-injected, openapi-visibility-scoped + ``AppService`` getters used by ``/openapi/v1/apps*``. These helpers + centralise the "row exists + status normal + openapi-visibility + gate passes" check so the controller can stay free of SQL. + """ + + def test_get_app_by_id_is_plain_session_get(self): + """``get_app_by_id`` must NOT apply status / visibility filters + — callers (e.g. the openapi auth pipeline) need to differentiate + 404 (missing) from 403 (``enable_api`` off) and would lose that + signal if the helper coalesced both into ``None``. + """ + mock_session = MagicMock() + sentinel_app = MagicMock(spec=App) + sentinel_app.status = "archived" # explicitly NOT "normal" + mock_session.get.return_value = sentinel_app + + assert AppService.get_app_by_id(mock_session, "app-uuid") is sentinel_app + mock_session.get.assert_called_once_with(App, "app-uuid") + + def test_get_app_by_id_returns_none_when_missing(self): + mock_session = MagicMock() + mock_session.get.return_value = None + + assert AppService.get_app_by_id(mock_session, "missing") is None + + def test_get_visible_app_by_id_returns_app_when_visible(self): + mock_session = MagicMock() + app = MagicMock(spec=App) + app.status = "normal" + mock_session.get.return_value = app + + with patch("services.app_service.is_openapi_visible", return_value=True): + assert AppService.get_visible_app_by_id(mock_session, "app-uuid") is app + + mock_session.get.assert_called_once_with(App, "app-uuid") + + def test_get_visible_app_by_id_returns_none_when_row_missing(self): + mock_session = MagicMock() + mock_session.get.return_value = None + + assert AppService.get_visible_app_by_id(mock_session, "missing") is None + + def test_get_visible_app_by_id_returns_none_when_status_not_normal(self): + """Soft-deleted/archived rows must not surface on the openapi + surface — the helper hides them by returning ``None``. + """ + mock_session = MagicMock() + app = MagicMock(spec=App) + app.status = "archived" + mock_session.get.return_value = app + + with patch("services.app_service.is_openapi_visible", return_value=True): + assert AppService.get_visible_app_by_id(mock_session, "app-uuid") is None + + def test_get_visible_app_by_id_returns_none_when_visibility_gate_rejects(self): + """``is_openapi_visible`` is the per-row counterpart to + ``apply_openapi_gate`` — when it returns False the helper must + treat the row as invisible (not "found but unauthorized"). + """ + mock_session = MagicMock() + app = MagicMock(spec=App) + app.status = "normal" + mock_session.get.return_value = app + + with patch("services.app_service.is_openapi_visible", return_value=False): + assert AppService.get_visible_app_by_id(mock_session, "app-uuid") is None + + def test_find_visible_apps_by_name_returns_scalars_through_visibility_gate(self): + """Tenant-scoped name lookup. The helper passes the SELECT through + ``apply_openapi_gate`` and materialises ``.scalars()`` into a list + so the controller can branch on length (404 / single / 409). + """ + mock_session = MagicMock() + rows = [MagicMock(spec=App), MagicMock(spec=App)] + mock_session.execute.return_value.scalars.return_value = iter(rows) + + with patch("services.app_service.apply_openapi_gate", side_effect=lambda q: q) as gate: + out = AppService.find_visible_apps_by_name(mock_session, name="my-app", tenant_id="tenant-1") + + assert out == rows + # Visibility gate must wrap the SELECT exactly once. + gate.assert_called_once() + mock_session.execute.assert_called_once() + + def test_find_visible_apps_by_name_returns_empty_list_on_no_match(self): + mock_session = MagicMock() + mock_session.execute.return_value.scalars.return_value = iter([]) + + with patch("services.app_service.apply_openapi_gate", side_effect=lambda q: q): + out = AppService.find_visible_apps_by_name(mock_session, name="nope", tenant_id="tenant-1") + + assert out == [] + + def test_find_visible_apps_by_ids_short_circuits_on_empty_input(self): + """Empty id list must not emit ``WHERE id IN ()`` — Postgres + rejects empty IN lists and the call is a guaranteed no-op + anyway. The helper returns ``[]`` without touching the session. + """ + mock_session = MagicMock() + + assert AppService.find_visible_apps_by_ids(mock_session, []) == [] + mock_session.execute.assert_not_called() + + def test_find_visible_apps_by_ids_passes_through_visibility_gate(self): + """Bulk fetch routes through ``apply_openapi_gate`` exactly once + and materialises the scalar rows. **No** status filter is + applied here — the EE permitted-external pipeline filters + non-normal hits in Python so its page count stays anchored. + """ + mock_session = MagicMock() + rows = [MagicMock(spec=App), MagicMock(spec=App)] + mock_session.execute.return_value.scalars.return_value.all.return_value = rows + + with patch("services.app_service.apply_openapi_gate", side_effect=lambda q: q) as gate: + out = AppService.find_visible_apps_by_ids(mock_session, ["a", "b"]) + + assert out == rows + gate.assert_called_once() + mock_session.execute.assert_called_once()