refactor: move db query from api leyer to service layer

This commit is contained in:
yunlu.wen
2026-05-23 14:11:52 +08:00
parent 95816a26b8
commit 98de360447
14 changed files with 380 additions and 138 deletions

View File

@ -62,9 +62,7 @@ class AccountApi(Resource):
).model_dump(mode="json")
account = AccountService.get_account_by_id(db.session, str(ctx.account_id)) if ctx.account_id else None
memberships = (
TenantService.get_account_memberships(db.session, str(ctx.account_id)) if ctx.account_id else []
)
memberships = TenantService.get_account_memberships(db.session, str(ctx.account_id)) if ctx.account_id else []
default_ws_id = _pick_default_workspace(memberships)
return AccountResponse(

View File

@ -10,7 +10,6 @@ from __future__ import annotations
import uuid as _uuid
from typing import Any, cast
import sqlalchemy as sa
from flask import request
from flask_restx import Resource
from pydantic import ValidationError
@ -43,9 +42,9 @@ from libs.oauth_bearer import (
require_workspace_member,
validate_bearer,
)
from models import App, Tenant
from models import App
from services.account_service import TenantService
from services.app_service import AppListParams, AppService
from services.openapi.visibility import apply_openapi_gate, is_openapi_visible
from services.tag_service import TagService
_APPS_READ_DECORATORS = [
@ -82,23 +81,14 @@ class AppReadResource(Resource):
is_uuid = False
if is_uuid:
app = db.session.get(App, str(parsed_uuid)) # normalised dashed form
if not app or app.status != "normal" or not is_openapi_visible(app):
# ``str(parsed_uuid)`` normalises to the canonical dashed form.
app = AppService.get_visible_app_by_id(db.session, str(parsed_uuid))
if app is None:
raise NotFound("app not found")
else:
if not workspace_id:
raise UnprocessableEntity("workspace_id is required for name-based lookup")
matches = list(
db.session.execute(
apply_openapi_gate(
sa.select(App).where(
App.name == app_id,
App.tenant_id == workspace_id,
App.status == "normal",
)
)
).scalars()
)
matches = AppService.find_visible_apps_by_name(db.session, name=app_id, tenant_id=workspace_id)
if len(matches) == 0:
raise NotFound("app not found")
if len(matches) > 1:
@ -210,12 +200,10 @@ class AppListApi(Resource):
tenant_name: str | None = None
if parsed_uuid is not None:
app: App | None = db.session.get(App, str(parsed_uuid))
if not app or app.status != "normal" or str(app.tenant_id) != workspace_id or not is_openapi_visible(app):
app: App | None = AppService.get_visible_app_by_id(db.session, str(parsed_uuid))
if app is None or str(app.tenant_id) != workspace_id:
return empty
tenant_name = db.session.execute(
sa.select(Tenant.name).where(Tenant.id == workspace_id)
).scalar_one_or_none()
tenant_name = TenantService.get_tenant_name(db.session, workspace_id)
item = AppListRow(
id=str(app.id),
name=app.name,
@ -255,9 +243,7 @@ class AppListApi(Resource):
tenant_name = None
if pagination.items:
tenant_name = db.session.execute(
sa.select(Tenant.name).where(Tenant.id == workspace_id)
).scalar_one_or_none()
tenant_name = TenantService.get_tenant_name(db.session, workspace_id)
items = [
AppListRow(

View File

@ -7,7 +7,6 @@ EE blueprint chain so this module is unreachable there.
from __future__ import annotations
import sqlalchemy as sa
from flask import request
from flask_restx import Resource
from pydantic import ValidationError
@ -29,10 +28,11 @@ from libs.oauth_bearer import (
require_scope,
validate_bearer,
)
from models import App, Tenant
from models import App
from services.account_service import TenantService
from services.app_service import AppService
from services.enterprise.app_permitted_service import list_permitted_apps
from services.openapi.license_gate import license_required
from services.openapi.visibility import apply_openapi_gate
@openapi_ns.route("/permitted-external-apps")
@ -68,15 +68,10 @@ class PermittedExternalAppsListApi(Resource):
return env.model_dump(mode="json"), 200
apps_by_id: dict[str, App] = {
str(a.id): a
for a in db.session.execute(apply_openapi_gate(sa.select(App).where(App.id.in_(page_result.app_ids))))
.scalars()
.all()
}
tenant_ids = list({a.tenant_id for a in apps_by_id.values()})
tenants_by_id = {
str(t.id): t for t in db.session.execute(sa.select(Tenant).where(Tenant.id.in_(tenant_ids))).scalars().all()
str(a.id): a for a in AppService.find_visible_apps_by_ids(db.session, page_result.app_ids)
}
tenant_ids = list({str(a.tenant_id) for a in apps_by_id.values()})
tenants_by_id = {str(t.id): t for t in TenantService.get_tenants_by_ids(db.session, tenant_ids)}
items: list[AppListRow] = []
for app_id in page_result.app_ids:

View File

@ -30,7 +30,9 @@ from libs.oauth_bearer import (
get_authenticator,
set_auth_ctx,
)
from models import App, Tenant, TenantStatus
from models import TenantStatus
from services.account_service import TenantService
from services.app_service import AppService
class BearerCheck:
@ -97,12 +99,12 @@ class AppResolver:
app_id = ctx.path_params.get("app_id")
if not app_id:
raise BadRequest("app_id is required in path")
app = db.session.get(App, app_id)
app = AppService.get_app_by_id(db.session, app_id)
if not app or app.status != "normal":
raise NotFound("app not found")
if not app.enable_api:
raise Forbidden("service_api_disabled")
tenant = db.session.get(Tenant, app.tenant_id)
tenant = TenantService.get_tenant_by_id(db.session, str(app.tenant_id))
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("workspace unavailable")
ctx.app, ctx.tenant = app, tenant

View File

@ -7,18 +7,16 @@ composition stays a flat list.
from __future__ import annotations
import uuid
from typing import Protocol
from flask import current_app
from flask_login import user_logged_in
from sqlalchemy import select
from controllers.openapi.auth.context import Context
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.oauth_bearer import SubjectType
from models import Account, TenantAccountJoin
from services.account_service import AccountService, TenantService
from services.end_user_service import EndUserService
from services.enterprise.enterprise_service import (
EnterpriseService,
@ -106,9 +104,7 @@ class AclStrategy:
return str(ctx.account_id) if ctx.account_id is not None else None
if ctx.subject_email is None:
return None
account = db.session.execute(
select(Account).where(Account.email == ctx.subject_email),
).scalar_one_or_none()
account = AccountService.get_account_by_email(db.session, ctx.subject_email)
return str(account.id) if account is not None else None
@ -125,19 +121,7 @@ class MembershipStrategy:
return False
if ctx.tenant is None:
return False
return _has_tenant_membership(ctx.account_id, ctx.tenant.id)
def _has_tenant_membership(account_id: uuid.UUID | str | None, tenant_id: str) -> bool:
if not account_id:
return False
row = db.session.execute(
select(TenantAccountJoin.id).where(
TenantAccountJoin.tenant_id == tenant_id,
TenantAccountJoin.account_id == account_id,
)
).scalar_one_or_none()
return row is not None
return TenantService.account_belongs_to_tenant(db.session, ctx.account_id, ctx.tenant.id)
def _login_as(user) -> None:
@ -159,7 +143,7 @@ class AccountMounter:
def mount(self, ctx: Context) -> None:
if ctx.account_id is None:
raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run")
account = db.session.get(Account, ctx.account_id)
account = AccountService.get_account_by_id(db.session, str(ctx.account_id))
if account is None:
raise RuntimeError("AccountMounter: account row missing for resolved bearer")
account.current_tenant = ctx.must_tenant

View File

@ -50,6 +50,7 @@ from libs.rate_limit import (
LIMIT_LOOKUP_PUBLIC,
rate_limit,
)
from services.account_service import TenantService
from services.oauth_device_flow import (
ACCOUNT_ISSUER_SENTINEL,
DEFAULT_POLL_INTERVAL_SECONDS,
@ -333,18 +334,7 @@ def _audit_cross_ip_if_needed(state) -> None:
def _build_account_poll_payload(account, tenant, mint) -> PollPayload:
"""Account branch of the shared `PollPayload` contract. SSO-only fields
(`subject_email`, `subject_issuer`) are intentionally omitted; see the
`PollPayload` docstring in `services.oauth_device_flow`.
"""
from models import Tenant, TenantAccountJoin
rows = (
db.session.query(Tenant, TenantAccountJoin)
.join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.account_id == account.id)
.all()
)
rows = TenantService.get_workspaces_for_account(db.session, str(account.id))
workspaces = [WorkspacePayload(id=str(t.id), name=t.name, role=getattr(m, "role", "")) for t, m in rows]
# Prefer active session tenant → DB-flagged current join → first membership.
default_ws_id = None

View File

@ -17,7 +17,6 @@ import secrets
from dataclasses import dataclass
from flask import jsonify, make_response, redirect, request
from sqlalchemy import func, select
from werkzeug.exceptions import (
BadGateway,
BadRequest,
@ -49,8 +48,7 @@ from libs.rate_limit import (
enforce,
rate_limit,
)
from models import Account
from models.account import AccountStatus
from services.account_service import AccountService
from services.enterprise.enterprise_service import EnterpriseService
from services.oauth_device_flow import (
DeviceFlowRedis,
@ -149,7 +147,7 @@ def sso_complete():
if state.status is not DeviceFlowStatus.PENDING:
raise Conflict("user_code_not_pending")
if _email_belongs_to_dify_account(claims["email"]):
if AccountService.has_active_account_with_email(db.session, claims["email"]):
_emit_external_rejection_audit(
state,
_RejectedClaims(subject_email=claims["email"], subject_issuer=claims["issuer"]),
@ -229,7 +227,7 @@ def approve_external():
if state.status is not DeviceFlowStatus.PENDING:
raise Conflict("user_code_not_pending")
if _email_belongs_to_dify_account(claims.subject_email):
if AccountService.has_active_account_with_email(db.session, claims.subject_email):
_emit_external_rejection_audit(state, claims, reason="email_belongs_to_dify_account")
raise Forbidden("email_belongs_to_dify_account")
@ -308,29 +306,6 @@ class _RejectedClaims:
subject_issuer: str
def _email_belongs_to_dify_account(email: str) -> bool:
"""External SSO subjects whose email matches an active Dify Account must
authenticate via the internal Dify login path (which mints dfoa_), not via
the external SSO device flow. Returning True here blocks dfoe_ minting.
Pending/uninitialized/banned/closed accounts do not block: pending and
uninitialized users may complete invitation via SSO; banned and closed
accounts are handled by separate enforcement paths.
"""
if not email:
return False
normalized = email.strip().lower()
if not normalized:
return False
row = db.session.execute(
select(Account.id).where(
func.lower(Account.email) == normalized,
Account.status == AccountStatus.ACTIVE,
),
).scalar_one_or_none()
return row is not None
def _emit_external_rejection_audit(state, claims, *, reason: str) -> None:
logger.warning(
"audit: oauth.device_flow_rejected subject_type=%s subject_email=%s subject_issuer=%s reason=%s",

View File

@ -164,6 +164,29 @@ class AccountService:
redis_client.delete(AccountService._get_refresh_token_key(refresh_token))
redis_client.delete(AccountService._get_account_refresh_token_key(account_id))
@staticmethod
def get_account_by_email(session: Session | scoped_session, email: str) -> Account | None:
"""Plain ``Account`` getter keyed by email. Case-sensitive — use
:meth:`has_active_account_with_email` for the case-insensitive
existence check that backs the SSO collision rule.
"""
return session.execute(select(Account).where(Account.email == email)).scalar_one_or_none()
@staticmethod
def has_active_account_with_email(session: Session | scoped_session, email: str) -> bool:
if not email:
return False
normalized = email.strip().lower()
if not normalized:
return False
row = session.execute(
select(Account.id).where(
func.lower(Account.email) == normalized,
Account.status == AccountStatus.ACTIVE,
)
).scalar_one_or_none()
return row is not None
@staticmethod
def get_account_by_id(session: Session | scoped_session, account_id: str) -> Account | None:
"""Plain ``Account`` getter — no banned check, no tenant rotation,
@ -1241,6 +1264,61 @@ class TenantService:
).all()
)
@staticmethod
def account_belongs_to_tenant(
session: Session | scoped_session,
account_id: uuid.UUID | str | None,
tenant_id: str,
) -> bool:
"""Existence check for ``TenantAccountJoin(account_id, tenant_id)``.
Backs the CE-deployment membership fallback in
``controllers.openapi.auth.strategies.MembershipStrategy``.
``None``/empty ``account_id`` short-circuits to ``False`` so SSO
bearers (no account) and missing identity collapse cleanly.
"""
if not account_id:
return False
row = session.execute(
select(TenantAccountJoin.id).where(
TenantAccountJoin.tenant_id == tenant_id,
TenantAccountJoin.account_id == account_id,
)
).scalar_one_or_none()
return row is not None
@staticmethod
def get_tenant_by_id(session: Session | scoped_session, tenant_id: str) -> Tenant | None:
"""Plain ``session.get(Tenant, tenant_id)`` — no status filter.
Callers map ``status == ARCHIVE`` to their own error code (the
openapi auth pipeline raises 403 ``workspace unavailable``).
"""
return session.get(Tenant, tenant_id)
@staticmethod
def get_tenants_by_ids(
session: Session | scoped_session,
tenant_ids: list[str],
) -> list[Tenant]:
"""Bulk ``Tenant`` fetch by primary-key list. Order is unspecified
— callers index by ``tenant.id`` (e.g. for cross-tenant denorm
in ``/openapi/v1/permitted-external-apps``).
Empty input short-circuits to ``[]`` to avoid emitting an
``IN ()`` SQL fragment.
"""
if not tenant_ids:
return []
return list(session.execute(select(Tenant).where(Tenant.id.in_(tenant_ids))).scalars().all())
@staticmethod
def get_tenant_name(session: Session | scoped_session, tenant_id: str) -> str | None:
"""Single-column tenant name read. Used by openapi list endpoints
to denormalize ``workspace_name`` onto each row without dragging
the full ``Tenant`` ORM entity through.
"""
return session.execute(select(Tenant.name).where(Tenant.id == tenant_id)).scalar_one_or_none()
@staticmethod
def find_workspace_for_account(
session: Session | scoped_session,

View File

@ -1,11 +1,13 @@
import json
import logging
from collections.abc import Sequence
from typing import Any, Literal, TypedDict, cast
import sqlalchemy as sa
from flask_sqlalchemy.pagination import Pagination
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session, scoped_session
from configs import dify_config
from constants.model_template import default_app_templates
@ -26,6 +28,7 @@ from models.tools import ApiToolProvider
from services.billing_service import BillingService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.openapi.visibility import apply_openapi_gate, is_openapi_visible
from services.tag_service import TagService
from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task
@ -56,6 +59,51 @@ class CreateAppParams(BaseModel):
class AppService:
@staticmethod
def get_app_by_id(
session: Session | scoped_session,
app_id: str,
) -> App | None:
return session.get(App, app_id)
@staticmethod
def get_visible_app_by_id(
session: Session | scoped_session,
app_id: str,
) -> App | None:
app = session.get(App, app_id)
if not app or app.status != "normal" or not is_openapi_visible(app):
return None
return app
@staticmethod
def find_visible_apps_by_ids(
session: Session | scoped_session,
app_ids: Sequence[str],
) -> list[App]:
if not app_ids:
return []
return list(session.execute(apply_openapi_gate(select(App).where(App.id.in_(list(app_ids))))).scalars().all())
@staticmethod
def find_visible_apps_by_name(
session: Session | scoped_session,
*,
name: str,
tenant_id: str,
) -> list[App]:
return list(
session.execute(
apply_openapi_gate(
select(App).where(
App.name == name,
App.tenant_id == tenant_id,
App.status == "normal",
)
)
).scalars()
)
def get_paginate_apps(self, user_id: str, tenant_id: str, params: AppListParams) -> Pagination | None:
"""
Get app list with pagination

View File

@ -12,7 +12,7 @@ from datetime import UTC, datetime, timedelta
from enum import StrEnum
from typing import Any, NotRequired, TypedDict
from sqlalchemy import Row, and_, func, select, update
from sqlalchemy import and_, func, select, update
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session, scoped_session
@ -492,6 +492,7 @@ def oauth_ttl_days(tenant_id: str | None = None) -> int:
return MAX_TTL_DAYS
return value
def subject_match_clauses(ctx: AuthContext) -> tuple[Any, ...]:
if ctx.subject_type == SubjectType.ACCOUNT:
return (OAuthAccessToken.account_id == str(ctx.account_id),)
@ -509,9 +510,7 @@ def list_active_sessions(
) -> list[OAuthAccessToken]:
return list(
session.execute(
select(
OAuthAccessToken
)
select(OAuthAccessToken)
.where(
and_(
*subject_match_clauses(ctx),

View File

@ -52,11 +52,12 @@ def test_acl_strategy_subject_mode_matrix(ent, access_mode, subject_type, expect
ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_not_called()
@patch("controllers.openapi.auth.strategies._has_tenant_membership")
def test_membership_strategy_uses_join_lookup(member):
@patch("controllers.openapi.auth.strategies.TenantService.account_belongs_to_tenant")
@patch("controllers.openapi.auth.strategies.db")
def test_membership_strategy_uses_join_lookup(db_mock, member):
member.return_value = True
assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True
member.assert_called_once_with("acc1", "t1")
member.assert_called_once_with(db_mock.session, "acc1", "t1")
def test_membership_strategy_rejects_external_sso():

View File

@ -1,7 +1,6 @@
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/."""
import builtins
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
@ -9,7 +8,6 @@ from flask.views import MethodView
from controllers.openapi import bp as openapi_bp
from controllers.openapi.oauth_device_sso import (
_email_belongs_to_dify_account,
approval_context,
approve_external,
sso_complete,
@ -79,27 +77,3 @@ def test_sso_complete_idp_callback_url_uses_canonical_path():
from controllers.openapi import oauth_device_sso
assert oauth_device_sso._SSO_COMPLETE_PATH == "/openapi/v1/oauth/device/sso-complete"
@pytest.mark.parametrize(
("email", "row", "expected"),
[
("alice@example.com", "acc1", True),
("alice@example.com", None, False),
("Alice@Example.COM", "acc1", True), # case-insensitive lookup
(" alice@example.com ", "acc1", True), # surrounding whitespace stripped
("", "acc1", False),
(" ", "acc1", False),
("", None, False),
],
)
@patch("controllers.openapi.oauth_device_sso.db")
def test_email_belongs_to_dify_account(db_mock, email, row, expected):
exec_result = MagicMock()
exec_result.scalar_one_or_none.return_value = row
db_mock.session.execute.return_value = exec_result
assert _email_belongs_to_dify_account(email) is expected
if email.strip():
db_mock.session.execute.assert_called_once()
else:
db_mock.session.execute.assert_not_called()

View File

@ -2021,6 +2021,44 @@ class TestSessionInjectedGetters:
assert AccountService.get_account_by_id(mock_session, "missing") is None
def test_get_account_by_email_returns_scalar_or_none(self):
"""Plain getter — case-sensitive equality (callers needing the
case-insensitive existence check use
:meth:`has_active_account_with_email`).
"""
mock_session = MagicMock()
sentinel = MagicMock(spec=Account)
mock_session.execute.return_value.scalar_one_or_none.return_value = sentinel
assert AccountService.get_account_by_email(mock_session, "alice@example.com") is sentinel
mock_session.execute.return_value.scalar_one_or_none.return_value = None
assert AccountService.get_account_by_email(mock_session, "ghost@example.com") is None
def test_account_belongs_to_tenant_short_circuits_on_falsy_account_id(self):
"""SSO bearers with no ``account_id`` (and any other falsy id)
must collapse to ``False`` without a DB round-trip — that's the
contract :class:`MembershipStrategy` relies on.
"""
mock_session = MagicMock()
assert TenantService.account_belongs_to_tenant(mock_session, None, "tenant-1") is False
assert TenantService.account_belongs_to_tenant(mock_session, "", "tenant-1") is False
mock_session.execute.assert_not_called()
def test_account_belongs_to_tenant_true_when_join_row_exists(self):
mock_session = MagicMock()
mock_session.execute.return_value.scalar_one_or_none.return_value = "join-id"
assert TenantService.account_belongs_to_tenant(mock_session, "user-1", "tenant-1") is True
mock_session.execute.assert_called_once()
def test_account_belongs_to_tenant_false_when_no_join(self):
mock_session = MagicMock()
mock_session.execute.return_value.scalar_one_or_none.return_value = None
assert TenantService.account_belongs_to_tenant(mock_session, "user-1", "tenant-1") is False
def test_get_account_memberships_returns_join_tenant_pairs(self):
"""Returns whatever ``session.query(...).join(...).filter(...).all()``
produces — ordering unspecified, callers pick the default
@ -2049,6 +2087,54 @@ class TestSessionInjectedGetters:
assert out == rows
assert mock_session.execute.called
def test_get_tenant_by_id_is_plain_session_get(self):
"""``get_tenant_by_id`` must NOT apply a status filter — the
openapi auth pipeline needs to map ``status == ARCHIVE`` to a
403, distinct from a 404 for "missing".
"""
from models import Tenant
mock_session = MagicMock()
sentinel = MagicMock(spec=Tenant)
mock_session.get.return_value = sentinel
assert TenantService.get_tenant_by_id(mock_session, "tenant-1") is sentinel
mock_session.get.assert_called_once_with(Tenant, "tenant-1")
def test_get_tenant_by_id_returns_none_when_missing(self):
mock_session = MagicMock()
mock_session.get.return_value = None
assert TenantService.get_tenant_by_id(mock_session, "missing") is None
def test_get_tenants_by_ids_short_circuits_on_empty_input(self):
"""Empty id list must not emit ``WHERE id IN ()``."""
mock_session = MagicMock()
assert TenantService.get_tenants_by_ids(mock_session, []) == []
mock_session.execute.assert_not_called()
def test_get_tenants_by_ids_returns_scalars(self):
mock_session = MagicMock()
tenants = [MagicMock(), MagicMock()]
mock_session.execute.return_value.scalars.return_value.all.return_value = tenants
assert TenantService.get_tenants_by_ids(mock_session, ["t1", "t2"]) == tenants
mock_session.execute.assert_called_once()
def test_get_tenant_name_returns_scalar_or_none(self):
"""Single-column lookup: ``session.execute(...).scalar_one_or_none()``
— used by openapi list endpoints to denormalise
``workspace_name`` onto each row.
"""
mock_session = MagicMock()
mock_session.execute.return_value.scalar_one_or_none.return_value = "Acme Inc."
assert TenantService.get_tenant_name(mock_session, "tenant-1") == "Acme Inc."
mock_session.execute.return_value.scalar_one_or_none.return_value = None
assert TenantService.get_tenant_name(mock_session, "missing") is None
def test_find_workspace_for_account_returns_first_row_or_none(self):
"""Per-id read returns ``session.execute(...).first()`` directly;
callers map ``None`` → 404 to avoid leaking workspace IDs across
@ -2058,10 +2144,7 @@ class TestSessionInjectedGetters:
sentinel_row = (MagicMock(), MagicMock())
mock_session.execute.return_value.first.return_value = sentinel_row
assert (
TenantService.find_workspace_for_account(mock_session, "user-123", "ws-1")
is sentinel_row
)
assert TenantService.find_workspace_for_account(mock_session, "user-123", "ws-1") is sentinel_row
mock_session.execute.return_value.first.return_value = None
assert TenantService.find_workspace_for_account(mock_session, "user-123", "ws-1") is None

View File

@ -0,0 +1,129 @@
from __future__ import annotations
from unittest.mock import MagicMock, patch
from models.model import App
from services.app_service import AppService
class TestOpenapiVisibilityHelpers:
"""Coverage for the session-injected, openapi-visibility-scoped
``AppService`` getters used by ``/openapi/v1/apps*``. These helpers
centralise the "row exists + status normal + openapi-visibility
gate passes" check so the controller can stay free of SQL.
"""
def test_get_app_by_id_is_plain_session_get(self):
"""``get_app_by_id`` must NOT apply status / visibility filters
— callers (e.g. the openapi auth pipeline) need to differentiate
404 (missing) from 403 (``enable_api`` off) and would lose that
signal if the helper coalesced both into ``None``.
"""
mock_session = MagicMock()
sentinel_app = MagicMock(spec=App)
sentinel_app.status = "archived" # explicitly NOT "normal"
mock_session.get.return_value = sentinel_app
assert AppService.get_app_by_id(mock_session, "app-uuid") is sentinel_app
mock_session.get.assert_called_once_with(App, "app-uuid")
def test_get_app_by_id_returns_none_when_missing(self):
mock_session = MagicMock()
mock_session.get.return_value = None
assert AppService.get_app_by_id(mock_session, "missing") is None
def test_get_visible_app_by_id_returns_app_when_visible(self):
mock_session = MagicMock()
app = MagicMock(spec=App)
app.status = "normal"
mock_session.get.return_value = app
with patch("services.app_service.is_openapi_visible", return_value=True):
assert AppService.get_visible_app_by_id(mock_session, "app-uuid") is app
mock_session.get.assert_called_once_with(App, "app-uuid")
def test_get_visible_app_by_id_returns_none_when_row_missing(self):
mock_session = MagicMock()
mock_session.get.return_value = None
assert AppService.get_visible_app_by_id(mock_session, "missing") is None
def test_get_visible_app_by_id_returns_none_when_status_not_normal(self):
"""Soft-deleted/archived rows must not surface on the openapi
surface — the helper hides them by returning ``None``.
"""
mock_session = MagicMock()
app = MagicMock(spec=App)
app.status = "archived"
mock_session.get.return_value = app
with patch("services.app_service.is_openapi_visible", return_value=True):
assert AppService.get_visible_app_by_id(mock_session, "app-uuid") is None
def test_get_visible_app_by_id_returns_none_when_visibility_gate_rejects(self):
"""``is_openapi_visible`` is the per-row counterpart to
``apply_openapi_gate`` — when it returns False the helper must
treat the row as invisible (not "found but unauthorized").
"""
mock_session = MagicMock()
app = MagicMock(spec=App)
app.status = "normal"
mock_session.get.return_value = app
with patch("services.app_service.is_openapi_visible", return_value=False):
assert AppService.get_visible_app_by_id(mock_session, "app-uuid") is None
def test_find_visible_apps_by_name_returns_scalars_through_visibility_gate(self):
"""Tenant-scoped name lookup. The helper passes the SELECT through
``apply_openapi_gate`` and materialises ``.scalars()`` into a list
so the controller can branch on length (404 / single / 409).
"""
mock_session = MagicMock()
rows = [MagicMock(spec=App), MagicMock(spec=App)]
mock_session.execute.return_value.scalars.return_value = iter(rows)
with patch("services.app_service.apply_openapi_gate", side_effect=lambda q: q) as gate:
out = AppService.find_visible_apps_by_name(mock_session, name="my-app", tenant_id="tenant-1")
assert out == rows
# Visibility gate must wrap the SELECT exactly once.
gate.assert_called_once()
mock_session.execute.assert_called_once()
def test_find_visible_apps_by_name_returns_empty_list_on_no_match(self):
mock_session = MagicMock()
mock_session.execute.return_value.scalars.return_value = iter([])
with patch("services.app_service.apply_openapi_gate", side_effect=lambda q: q):
out = AppService.find_visible_apps_by_name(mock_session, name="nope", tenant_id="tenant-1")
assert out == []
def test_find_visible_apps_by_ids_short_circuits_on_empty_input(self):
"""Empty id list must not emit ``WHERE id IN ()`` — Postgres
rejects empty IN lists and the call is a guaranteed no-op
anyway. The helper returns ``[]`` without touching the session.
"""
mock_session = MagicMock()
assert AppService.find_visible_apps_by_ids(mock_session, []) == []
mock_session.execute.assert_not_called()
def test_find_visible_apps_by_ids_passes_through_visibility_gate(self):
"""Bulk fetch routes through ``apply_openapi_gate`` exactly once
and materialises the scalar rows. **No** status filter is
applied here — the EE permitted-external pipeline filters
non-normal hits in Python so its page count stays anchored.
"""
mock_session = MagicMock()
rows = [MagicMock(spec=App), MagicMock(spec=App)]
mock_session.execute.return_value.scalars.return_value.all.return_value = rows
with patch("services.app_service.apply_openapi_gate", side_effect=lambda q: q) as gate:
out = AppService.find_visible_apps_by_ids(mock_session, ["a", "b"])
assert out == rows
gate.assert_called_once()
mock_session.execute.assert_called_once()