mirror of
https://github.com/langgenius/dify.git
synced 2026-05-28 04:43:33 +08:00
refactor: move db query from api leyer to service layer
This commit is contained in:
@ -62,9 +62,7 @@ class AccountApi(Resource):
|
||||
).model_dump(mode="json")
|
||||
|
||||
account = AccountService.get_account_by_id(db.session, str(ctx.account_id)) if ctx.account_id else None
|
||||
memberships = (
|
||||
TenantService.get_account_memberships(db.session, str(ctx.account_id)) if ctx.account_id else []
|
||||
)
|
||||
memberships = TenantService.get_account_memberships(db.session, str(ctx.account_id)) if ctx.account_id else []
|
||||
default_ws_id = _pick_default_workspace(memberships)
|
||||
|
||||
return AccountResponse(
|
||||
|
||||
@ -10,7 +10,6 @@ from __future__ import annotations
|
||||
import uuid as _uuid
|
||||
from typing import Any, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
@ -43,9 +42,9 @@ from libs.oauth_bearer import (
|
||||
require_workspace_member,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import App, Tenant
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppListParams, AppService
|
||||
from services.openapi.visibility import apply_openapi_gate, is_openapi_visible
|
||||
from services.tag_service import TagService
|
||||
|
||||
_APPS_READ_DECORATORS = [
|
||||
@ -82,23 +81,14 @@ class AppReadResource(Resource):
|
||||
is_uuid = False
|
||||
|
||||
if is_uuid:
|
||||
app = db.session.get(App, str(parsed_uuid)) # normalised dashed form
|
||||
if not app or app.status != "normal" or not is_openapi_visible(app):
|
||||
# ``str(parsed_uuid)`` normalises to the canonical dashed form.
|
||||
app = AppService.get_visible_app_by_id(db.session, str(parsed_uuid))
|
||||
if app is None:
|
||||
raise NotFound("app not found")
|
||||
else:
|
||||
if not workspace_id:
|
||||
raise UnprocessableEntity("workspace_id is required for name-based lookup")
|
||||
matches = list(
|
||||
db.session.execute(
|
||||
apply_openapi_gate(
|
||||
sa.select(App).where(
|
||||
App.name == app_id,
|
||||
App.tenant_id == workspace_id,
|
||||
App.status == "normal",
|
||||
)
|
||||
)
|
||||
).scalars()
|
||||
)
|
||||
matches = AppService.find_visible_apps_by_name(db.session, name=app_id, tenant_id=workspace_id)
|
||||
if len(matches) == 0:
|
||||
raise NotFound("app not found")
|
||||
if len(matches) > 1:
|
||||
@ -210,12 +200,10 @@ class AppListApi(Resource):
|
||||
|
||||
tenant_name: str | None = None
|
||||
if parsed_uuid is not None:
|
||||
app: App | None = db.session.get(App, str(parsed_uuid))
|
||||
if not app or app.status != "normal" or str(app.tenant_id) != workspace_id or not is_openapi_visible(app):
|
||||
app: App | None = AppService.get_visible_app_by_id(db.session, str(parsed_uuid))
|
||||
if app is None or str(app.tenant_id) != workspace_id:
|
||||
return empty
|
||||
tenant_name = db.session.execute(
|
||||
sa.select(Tenant.name).where(Tenant.id == workspace_id)
|
||||
).scalar_one_or_none()
|
||||
tenant_name = TenantService.get_tenant_name(db.session, workspace_id)
|
||||
item = AppListRow(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
@ -255,9 +243,7 @@ class AppListApi(Resource):
|
||||
|
||||
tenant_name = None
|
||||
if pagination.items:
|
||||
tenant_name = db.session.execute(
|
||||
sa.select(Tenant.name).where(Tenant.id == workspace_id)
|
||||
).scalar_one_or_none()
|
||||
tenant_name = TenantService.get_tenant_name(db.session, workspace_id)
|
||||
|
||||
items = [
|
||||
AppListRow(
|
||||
|
||||
@ -7,7 +7,6 @@ EE blueprint chain so this module is unreachable there.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
@ -29,10 +28,11 @@ from libs.oauth_bearer import (
|
||||
require_scope,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import App, Tenant
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.app_permitted_service import list_permitted_apps
|
||||
from services.openapi.license_gate import license_required
|
||||
from services.openapi.visibility import apply_openapi_gate
|
||||
|
||||
|
||||
@openapi_ns.route("/permitted-external-apps")
|
||||
@ -68,15 +68,10 @@ class PermittedExternalAppsListApi(Resource):
|
||||
return env.model_dump(mode="json"), 200
|
||||
|
||||
apps_by_id: dict[str, App] = {
|
||||
str(a.id): a
|
||||
for a in db.session.execute(apply_openapi_gate(sa.select(App).where(App.id.in_(page_result.app_ids))))
|
||||
.scalars()
|
||||
.all()
|
||||
}
|
||||
tenant_ids = list({a.tenant_id for a in apps_by_id.values()})
|
||||
tenants_by_id = {
|
||||
str(t.id): t for t in db.session.execute(sa.select(Tenant).where(Tenant.id.in_(tenant_ids))).scalars().all()
|
||||
str(a.id): a for a in AppService.find_visible_apps_by_ids(db.session, page_result.app_ids)
|
||||
}
|
||||
tenant_ids = list({str(a.tenant_id) for a in apps_by_id.values()})
|
||||
tenants_by_id = {str(t.id): t for t in TenantService.get_tenants_by_ids(db.session, tenant_ids)}
|
||||
|
||||
items: list[AppListRow] = []
|
||||
for app_id in page_result.app_ids:
|
||||
|
||||
@ -30,7 +30,9 @@ from libs.oauth_bearer import (
|
||||
get_authenticator,
|
||||
set_auth_ctx,
|
||||
)
|
||||
from models import App, Tenant, TenantStatus
|
||||
from models import TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppService
|
||||
|
||||
|
||||
class BearerCheck:
|
||||
@ -97,12 +99,12 @@ class AppResolver:
|
||||
app_id = ctx.path_params.get("app_id")
|
||||
if not app_id:
|
||||
raise BadRequest("app_id is required in path")
|
||||
app = db.session.get(App, app_id)
|
||||
app = AppService.get_app_by_id(db.session, app_id)
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
if not app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
tenant = db.session.get(Tenant, app.tenant_id)
|
||||
tenant = TenantService.get_tenant_by_id(db.session, str(app.tenant_id))
|
||||
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("workspace unavailable")
|
||||
ctx.app, ctx.tenant = app, tenant
|
||||
|
||||
@ -7,18 +7,16 @@ composition stays a flat list.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Protocol
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import user_logged_in
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from models import Account, TenantAccountJoin
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.end_user_service import EndUserService
|
||||
from services.enterprise.enterprise_service import (
|
||||
EnterpriseService,
|
||||
@ -106,9 +104,7 @@ class AclStrategy:
|
||||
return str(ctx.account_id) if ctx.account_id is not None else None
|
||||
if ctx.subject_email is None:
|
||||
return None
|
||||
account = db.session.execute(
|
||||
select(Account).where(Account.email == ctx.subject_email),
|
||||
).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email(db.session, ctx.subject_email)
|
||||
return str(account.id) if account is not None else None
|
||||
|
||||
|
||||
@ -125,19 +121,7 @@ class MembershipStrategy:
|
||||
return False
|
||||
if ctx.tenant is None:
|
||||
return False
|
||||
return _has_tenant_membership(ctx.account_id, ctx.tenant.id)
|
||||
|
||||
|
||||
def _has_tenant_membership(account_id: uuid.UUID | str | None, tenant_id: str) -> bool:
|
||||
if not account_id:
|
||||
return False
|
||||
row = db.session.execute(
|
||||
select(TenantAccountJoin.id).where(
|
||||
TenantAccountJoin.tenant_id == tenant_id,
|
||||
TenantAccountJoin.account_id == account_id,
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
return row is not None
|
||||
return TenantService.account_belongs_to_tenant(db.session, ctx.account_id, ctx.tenant.id)
|
||||
|
||||
|
||||
def _login_as(user) -> None:
|
||||
@ -159,7 +143,7 @@ class AccountMounter:
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.account_id is None:
|
||||
raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run")
|
||||
account = db.session.get(Account, ctx.account_id)
|
||||
account = AccountService.get_account_by_id(db.session, str(ctx.account_id))
|
||||
if account is None:
|
||||
raise RuntimeError("AccountMounter: account row missing for resolved bearer")
|
||||
account.current_tenant = ctx.must_tenant
|
||||
|
||||
@ -50,6 +50,7 @@ from libs.rate_limit import (
|
||||
LIMIT_LOOKUP_PUBLIC,
|
||||
rate_limit,
|
||||
)
|
||||
from services.account_service import TenantService
|
||||
from services.oauth_device_flow import (
|
||||
ACCOUNT_ISSUER_SENTINEL,
|
||||
DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
@ -333,18 +334,7 @@ def _audit_cross_ip_if_needed(state) -> None:
|
||||
|
||||
|
||||
def _build_account_poll_payload(account, tenant, mint) -> PollPayload:
|
||||
"""Account branch of the shared `PollPayload` contract. SSO-only fields
|
||||
(`subject_email`, `subject_issuer`) are intentionally omitted; see the
|
||||
`PollPayload` docstring in `services.oauth_device_flow`.
|
||||
"""
|
||||
from models import Tenant, TenantAccountJoin
|
||||
|
||||
rows = (
|
||||
db.session.query(Tenant, TenantAccountJoin)
|
||||
.join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.filter(TenantAccountJoin.account_id == account.id)
|
||||
.all()
|
||||
)
|
||||
rows = TenantService.get_workspaces_for_account(db.session, str(account.id))
|
||||
workspaces = [WorkspacePayload(id=str(t.id), name=t.name, role=getattr(m, "role", "")) for t, m in rows]
|
||||
# Prefer active session tenant → DB-flagged current join → first membership.
|
||||
default_ws_id = None
|
||||
|
||||
@ -17,7 +17,6 @@ import secrets
|
||||
from dataclasses import dataclass
|
||||
|
||||
from flask import jsonify, make_response, redirect, request
|
||||
from sqlalchemy import func, select
|
||||
from werkzeug.exceptions import (
|
||||
BadGateway,
|
||||
BadRequest,
|
||||
@ -49,8 +48,7 @@ from libs.rate_limit import (
|
||||
enforce,
|
||||
rate_limit,
|
||||
)
|
||||
from models import Account
|
||||
from models.account import AccountStatus
|
||||
from services.account_service import AccountService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.oauth_device_flow import (
|
||||
DeviceFlowRedis,
|
||||
@ -149,7 +147,7 @@ def sso_complete():
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise Conflict("user_code_not_pending")
|
||||
|
||||
if _email_belongs_to_dify_account(claims["email"]):
|
||||
if AccountService.has_active_account_with_email(db.session, claims["email"]):
|
||||
_emit_external_rejection_audit(
|
||||
state,
|
||||
_RejectedClaims(subject_email=claims["email"], subject_issuer=claims["issuer"]),
|
||||
@ -229,7 +227,7 @@ def approve_external():
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise Conflict("user_code_not_pending")
|
||||
|
||||
if _email_belongs_to_dify_account(claims.subject_email):
|
||||
if AccountService.has_active_account_with_email(db.session, claims.subject_email):
|
||||
_emit_external_rejection_audit(state, claims, reason="email_belongs_to_dify_account")
|
||||
raise Forbidden("email_belongs_to_dify_account")
|
||||
|
||||
@ -308,29 +306,6 @@ class _RejectedClaims:
|
||||
subject_issuer: str
|
||||
|
||||
|
||||
def _email_belongs_to_dify_account(email: str) -> bool:
|
||||
"""External SSO subjects whose email matches an active Dify Account must
|
||||
authenticate via the internal Dify login path (which mints dfoa_), not via
|
||||
the external SSO device flow. Returning True here blocks dfoe_ minting.
|
||||
|
||||
Pending/uninitialized/banned/closed accounts do not block: pending and
|
||||
uninitialized users may complete invitation via SSO; banned and closed
|
||||
accounts are handled by separate enforcement paths.
|
||||
"""
|
||||
if not email:
|
||||
return False
|
||||
normalized = email.strip().lower()
|
||||
if not normalized:
|
||||
return False
|
||||
row = db.session.execute(
|
||||
select(Account.id).where(
|
||||
func.lower(Account.email) == normalized,
|
||||
Account.status == AccountStatus.ACTIVE,
|
||||
),
|
||||
).scalar_one_or_none()
|
||||
return row is not None
|
||||
|
||||
|
||||
def _emit_external_rejection_audit(state, claims, *, reason: str) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_rejected subject_type=%s subject_email=%s subject_issuer=%s reason=%s",
|
||||
|
||||
@ -164,6 +164,29 @@ class AccountService:
|
||||
redis_client.delete(AccountService._get_refresh_token_key(refresh_token))
|
||||
redis_client.delete(AccountService._get_account_refresh_token_key(account_id))
|
||||
|
||||
@staticmethod
|
||||
def get_account_by_email(session: Session | scoped_session, email: str) -> Account | None:
|
||||
"""Plain ``Account`` getter keyed by email. Case-sensitive — use
|
||||
:meth:`has_active_account_with_email` for the case-insensitive
|
||||
existence check that backs the SSO collision rule.
|
||||
"""
|
||||
return session.execute(select(Account).where(Account.email == email)).scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
def has_active_account_with_email(session: Session | scoped_session, email: str) -> bool:
|
||||
if not email:
|
||||
return False
|
||||
normalized = email.strip().lower()
|
||||
if not normalized:
|
||||
return False
|
||||
row = session.execute(
|
||||
select(Account.id).where(
|
||||
func.lower(Account.email) == normalized,
|
||||
Account.status == AccountStatus.ACTIVE,
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
return row is not None
|
||||
|
||||
@staticmethod
|
||||
def get_account_by_id(session: Session | scoped_session, account_id: str) -> Account | None:
|
||||
"""Plain ``Account`` getter — no banned check, no tenant rotation,
|
||||
@ -1241,6 +1264,61 @@ class TenantService:
|
||||
).all()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def account_belongs_to_tenant(
|
||||
session: Session | scoped_session,
|
||||
account_id: uuid.UUID | str | None,
|
||||
tenant_id: str,
|
||||
) -> bool:
|
||||
"""Existence check for ``TenantAccountJoin(account_id, tenant_id)``.
|
||||
Backs the CE-deployment membership fallback in
|
||||
``controllers.openapi.auth.strategies.MembershipStrategy``.
|
||||
|
||||
``None``/empty ``account_id`` short-circuits to ``False`` so SSO
|
||||
bearers (no account) and missing identity collapse cleanly.
|
||||
"""
|
||||
if not account_id:
|
||||
return False
|
||||
row = session.execute(
|
||||
select(TenantAccountJoin.id).where(
|
||||
TenantAccountJoin.tenant_id == tenant_id,
|
||||
TenantAccountJoin.account_id == account_id,
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
return row is not None
|
||||
|
||||
@staticmethod
|
||||
def get_tenant_by_id(session: Session | scoped_session, tenant_id: str) -> Tenant | None:
|
||||
"""Plain ``session.get(Tenant, tenant_id)`` — no status filter.
|
||||
Callers map ``status == ARCHIVE`` to their own error code (the
|
||||
openapi auth pipeline raises 403 ``workspace unavailable``).
|
||||
"""
|
||||
return session.get(Tenant, tenant_id)
|
||||
|
||||
@staticmethod
|
||||
def get_tenants_by_ids(
|
||||
session: Session | scoped_session,
|
||||
tenant_ids: list[str],
|
||||
) -> list[Tenant]:
|
||||
"""Bulk ``Tenant`` fetch by primary-key list. Order is unspecified
|
||||
— callers index by ``tenant.id`` (e.g. for cross-tenant denorm
|
||||
in ``/openapi/v1/permitted-external-apps``).
|
||||
|
||||
Empty input short-circuits to ``[]`` to avoid emitting an
|
||||
``IN ()`` SQL fragment.
|
||||
"""
|
||||
if not tenant_ids:
|
||||
return []
|
||||
return list(session.execute(select(Tenant).where(Tenant.id.in_(tenant_ids))).scalars().all())
|
||||
|
||||
@staticmethod
|
||||
def get_tenant_name(session: Session | scoped_session, tenant_id: str) -> str | None:
|
||||
"""Single-column tenant name read. Used by openapi list endpoints
|
||||
to denormalize ``workspace_name`` onto each row without dragging
|
||||
the full ``Tenant`` ORM entity through.
|
||||
"""
|
||||
return session.execute(select(Tenant.name).where(Tenant.id == tenant_id)).scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
def find_workspace_for_account(
|
||||
session: Session | scoped_session,
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal, TypedDict, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask_sqlalchemy.pagination import Pagination
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
|
||||
from configs import dify_config
|
||||
from constants.model_template import default_app_templates
|
||||
@ -26,6 +28,7 @@ from models.tools import ApiToolProvider
|
||||
from services.billing_service import BillingService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
from services.openapi.visibility import apply_openapi_gate, is_openapi_visible
|
||||
from services.tag_service import TagService
|
||||
from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task
|
||||
|
||||
@ -56,6 +59,51 @@ class CreateAppParams(BaseModel):
|
||||
|
||||
|
||||
class AppService:
|
||||
@staticmethod
|
||||
def get_app_by_id(
|
||||
session: Session | scoped_session,
|
||||
app_id: str,
|
||||
) -> App | None:
|
||||
return session.get(App, app_id)
|
||||
|
||||
@staticmethod
|
||||
def get_visible_app_by_id(
|
||||
session: Session | scoped_session,
|
||||
app_id: str,
|
||||
) -> App | None:
|
||||
app = session.get(App, app_id)
|
||||
if not app or app.status != "normal" or not is_openapi_visible(app):
|
||||
return None
|
||||
return app
|
||||
|
||||
@staticmethod
|
||||
def find_visible_apps_by_ids(
|
||||
session: Session | scoped_session,
|
||||
app_ids: Sequence[str],
|
||||
) -> list[App]:
|
||||
if not app_ids:
|
||||
return []
|
||||
return list(session.execute(apply_openapi_gate(select(App).where(App.id.in_(list(app_ids))))).scalars().all())
|
||||
|
||||
@staticmethod
|
||||
def find_visible_apps_by_name(
|
||||
session: Session | scoped_session,
|
||||
*,
|
||||
name: str,
|
||||
tenant_id: str,
|
||||
) -> list[App]:
|
||||
return list(
|
||||
session.execute(
|
||||
apply_openapi_gate(
|
||||
select(App).where(
|
||||
App.name == name,
|
||||
App.tenant_id == tenant_id,
|
||||
App.status == "normal",
|
||||
)
|
||||
)
|
||||
).scalars()
|
||||
)
|
||||
|
||||
def get_paginate_apps(self, user_id: str, tenant_id: str, params: AppListParams) -> Pagination | None:
|
||||
"""
|
||||
Get app list with pagination
|
||||
|
||||
@ -12,7 +12,7 @@ from datetime import UTC, datetime, timedelta
|
||||
from enum import StrEnum
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from sqlalchemy import Row, and_, func, select, update
|
||||
from sqlalchemy import and_, func, select, update
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
|
||||
@ -492,6 +492,7 @@ def oauth_ttl_days(tenant_id: str | None = None) -> int:
|
||||
return MAX_TTL_DAYS
|
||||
return value
|
||||
|
||||
|
||||
def subject_match_clauses(ctx: AuthContext) -> tuple[Any, ...]:
|
||||
if ctx.subject_type == SubjectType.ACCOUNT:
|
||||
return (OAuthAccessToken.account_id == str(ctx.account_id),)
|
||||
@ -509,9 +510,7 @@ def list_active_sessions(
|
||||
) -> list[OAuthAccessToken]:
|
||||
return list(
|
||||
session.execute(
|
||||
select(
|
||||
OAuthAccessToken
|
||||
)
|
||||
select(OAuthAccessToken)
|
||||
.where(
|
||||
and_(
|
||||
*subject_match_clauses(ctx),
|
||||
|
||||
@ -52,11 +52,12 @@ def test_acl_strategy_subject_mode_matrix(ent, access_mode, subject_type, expect
|
||||
ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.strategies._has_tenant_membership")
|
||||
def test_membership_strategy_uses_join_lookup(member):
|
||||
@patch("controllers.openapi.auth.strategies.TenantService.account_belongs_to_tenant")
|
||||
@patch("controllers.openapi.auth.strategies.db")
|
||||
def test_membership_strategy_uses_join_lookup(db_mock, member):
|
||||
member.return_value = True
|
||||
assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True
|
||||
member.assert_called_once_with("acc1", "t1")
|
||||
member.assert_called_once_with(db_mock.session, "acc1", "t1")
|
||||
|
||||
|
||||
def test_membership_strategy_rejects_external_sso():
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/."""
|
||||
|
||||
import builtins
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
@ -9,7 +8,6 @@ from flask.views import MethodView
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi.oauth_device_sso import (
|
||||
_email_belongs_to_dify_account,
|
||||
approval_context,
|
||||
approve_external,
|
||||
sso_complete,
|
||||
@ -79,27 +77,3 @@ def test_sso_complete_idp_callback_url_uses_canonical_path():
|
||||
from controllers.openapi import oauth_device_sso
|
||||
|
||||
assert oauth_device_sso._SSO_COMPLETE_PATH == "/openapi/v1/oauth/device/sso-complete"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("email", "row", "expected"),
|
||||
[
|
||||
("alice@example.com", "acc1", True),
|
||||
("alice@example.com", None, False),
|
||||
("Alice@Example.COM", "acc1", True), # case-insensitive lookup
|
||||
(" alice@example.com ", "acc1", True), # surrounding whitespace stripped
|
||||
("", "acc1", False),
|
||||
(" ", "acc1", False),
|
||||
("", None, False),
|
||||
],
|
||||
)
|
||||
@patch("controllers.openapi.oauth_device_sso.db")
|
||||
def test_email_belongs_to_dify_account(db_mock, email, row, expected):
|
||||
exec_result = MagicMock()
|
||||
exec_result.scalar_one_or_none.return_value = row
|
||||
db_mock.session.execute.return_value = exec_result
|
||||
assert _email_belongs_to_dify_account(email) is expected
|
||||
if email.strip():
|
||||
db_mock.session.execute.assert_called_once()
|
||||
else:
|
||||
db_mock.session.execute.assert_not_called()
|
||||
|
||||
@ -2021,6 +2021,44 @@ class TestSessionInjectedGetters:
|
||||
|
||||
assert AccountService.get_account_by_id(mock_session, "missing") is None
|
||||
|
||||
def test_get_account_by_email_returns_scalar_or_none(self):
|
||||
"""Plain getter — case-sensitive equality (callers needing the
|
||||
case-insensitive existence check use
|
||||
:meth:`has_active_account_with_email`).
|
||||
"""
|
||||
mock_session = MagicMock()
|
||||
sentinel = MagicMock(spec=Account)
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = sentinel
|
||||
|
||||
assert AccountService.get_account_by_email(mock_session, "alice@example.com") is sentinel
|
||||
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
assert AccountService.get_account_by_email(mock_session, "ghost@example.com") is None
|
||||
|
||||
def test_account_belongs_to_tenant_short_circuits_on_falsy_account_id(self):
|
||||
"""SSO bearers with no ``account_id`` (and any other falsy id)
|
||||
must collapse to ``False`` without a DB round-trip — that's the
|
||||
contract :class:`MembershipStrategy` relies on.
|
||||
"""
|
||||
mock_session = MagicMock()
|
||||
|
||||
assert TenantService.account_belongs_to_tenant(mock_session, None, "tenant-1") is False
|
||||
assert TenantService.account_belongs_to_tenant(mock_session, "", "tenant-1") is False
|
||||
mock_session.execute.assert_not_called()
|
||||
|
||||
def test_account_belongs_to_tenant_true_when_join_row_exists(self):
|
||||
mock_session = MagicMock()
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = "join-id"
|
||||
|
||||
assert TenantService.account_belongs_to_tenant(mock_session, "user-1", "tenant-1") is True
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
def test_account_belongs_to_tenant_false_when_no_join(self):
|
||||
mock_session = MagicMock()
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
|
||||
assert TenantService.account_belongs_to_tenant(mock_session, "user-1", "tenant-1") is False
|
||||
|
||||
def test_get_account_memberships_returns_join_tenant_pairs(self):
|
||||
"""Returns whatever ``session.query(...).join(...).filter(...).all()``
|
||||
produces — ordering unspecified, callers pick the default
|
||||
@ -2049,6 +2087,54 @@ class TestSessionInjectedGetters:
|
||||
assert out == rows
|
||||
assert mock_session.execute.called
|
||||
|
||||
def test_get_tenant_by_id_is_plain_session_get(self):
|
||||
"""``get_tenant_by_id`` must NOT apply a status filter — the
|
||||
openapi auth pipeline needs to map ``status == ARCHIVE`` to a
|
||||
403, distinct from a 404 for "missing".
|
||||
"""
|
||||
from models import Tenant
|
||||
|
||||
mock_session = MagicMock()
|
||||
sentinel = MagicMock(spec=Tenant)
|
||||
mock_session.get.return_value = sentinel
|
||||
|
||||
assert TenantService.get_tenant_by_id(mock_session, "tenant-1") is sentinel
|
||||
mock_session.get.assert_called_once_with(Tenant, "tenant-1")
|
||||
|
||||
def test_get_tenant_by_id_returns_none_when_missing(self):
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
|
||||
assert TenantService.get_tenant_by_id(mock_session, "missing") is None
|
||||
|
||||
def test_get_tenants_by_ids_short_circuits_on_empty_input(self):
|
||||
"""Empty id list must not emit ``WHERE id IN ()``."""
|
||||
mock_session = MagicMock()
|
||||
|
||||
assert TenantService.get_tenants_by_ids(mock_session, []) == []
|
||||
mock_session.execute.assert_not_called()
|
||||
|
||||
def test_get_tenants_by_ids_returns_scalars(self):
|
||||
mock_session = MagicMock()
|
||||
tenants = [MagicMock(), MagicMock()]
|
||||
mock_session.execute.return_value.scalars.return_value.all.return_value = tenants
|
||||
|
||||
assert TenantService.get_tenants_by_ids(mock_session, ["t1", "t2"]) == tenants
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
def test_get_tenant_name_returns_scalar_or_none(self):
|
||||
"""Single-column lookup: ``session.execute(...).scalar_one_or_none()``
|
||||
— used by openapi list endpoints to denormalise
|
||||
``workspace_name`` onto each row.
|
||||
"""
|
||||
mock_session = MagicMock()
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = "Acme Inc."
|
||||
|
||||
assert TenantService.get_tenant_name(mock_session, "tenant-1") == "Acme Inc."
|
||||
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
assert TenantService.get_tenant_name(mock_session, "missing") is None
|
||||
|
||||
def test_find_workspace_for_account_returns_first_row_or_none(self):
|
||||
"""Per-id read returns ``session.execute(...).first()`` directly;
|
||||
callers map ``None`` → 404 to avoid leaking workspace IDs across
|
||||
@ -2058,10 +2144,7 @@ class TestSessionInjectedGetters:
|
||||
sentinel_row = (MagicMock(), MagicMock())
|
||||
mock_session.execute.return_value.first.return_value = sentinel_row
|
||||
|
||||
assert (
|
||||
TenantService.find_workspace_for_account(mock_session, "user-123", "ws-1")
|
||||
is sentinel_row
|
||||
)
|
||||
assert TenantService.find_workspace_for_account(mock_session, "user-123", "ws-1") is sentinel_row
|
||||
|
||||
mock_session.execute.return_value.first.return_value = None
|
||||
assert TenantService.find_workspace_for_account(mock_session, "user-123", "ws-1") is None
|
||||
|
||||
129
api/tests/unit_tests/services/test_app_service.py
Normal file
129
api/tests/unit_tests/services/test_app_service.py
Normal file
@ -0,0 +1,129 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from models.model import App
|
||||
from services.app_service import AppService
|
||||
|
||||
|
||||
class TestOpenapiVisibilityHelpers:
|
||||
"""Coverage for the session-injected, openapi-visibility-scoped
|
||||
``AppService`` getters used by ``/openapi/v1/apps*``. These helpers
|
||||
centralise the "row exists + status normal + openapi-visibility
|
||||
gate passes" check so the controller can stay free of SQL.
|
||||
"""
|
||||
|
||||
def test_get_app_by_id_is_plain_session_get(self):
|
||||
"""``get_app_by_id`` must NOT apply status / visibility filters
|
||||
— callers (e.g. the openapi auth pipeline) need to differentiate
|
||||
404 (missing) from 403 (``enable_api`` off) and would lose that
|
||||
signal if the helper coalesced both into ``None``.
|
||||
"""
|
||||
mock_session = MagicMock()
|
||||
sentinel_app = MagicMock(spec=App)
|
||||
sentinel_app.status = "archived" # explicitly NOT "normal"
|
||||
mock_session.get.return_value = sentinel_app
|
||||
|
||||
assert AppService.get_app_by_id(mock_session, "app-uuid") is sentinel_app
|
||||
mock_session.get.assert_called_once_with(App, "app-uuid")
|
||||
|
||||
def test_get_app_by_id_returns_none_when_missing(self):
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
|
||||
assert AppService.get_app_by_id(mock_session, "missing") is None
|
||||
|
||||
def test_get_visible_app_by_id_returns_app_when_visible(self):
|
||||
mock_session = MagicMock()
|
||||
app = MagicMock(spec=App)
|
||||
app.status = "normal"
|
||||
mock_session.get.return_value = app
|
||||
|
||||
with patch("services.app_service.is_openapi_visible", return_value=True):
|
||||
assert AppService.get_visible_app_by_id(mock_session, "app-uuid") is app
|
||||
|
||||
mock_session.get.assert_called_once_with(App, "app-uuid")
|
||||
|
||||
def test_get_visible_app_by_id_returns_none_when_row_missing(self):
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
|
||||
assert AppService.get_visible_app_by_id(mock_session, "missing") is None
|
||||
|
||||
def test_get_visible_app_by_id_returns_none_when_status_not_normal(self):
|
||||
"""Soft-deleted/archived rows must not surface on the openapi
|
||||
surface — the helper hides them by returning ``None``.
|
||||
"""
|
||||
mock_session = MagicMock()
|
||||
app = MagicMock(spec=App)
|
||||
app.status = "archived"
|
||||
mock_session.get.return_value = app
|
||||
|
||||
with patch("services.app_service.is_openapi_visible", return_value=True):
|
||||
assert AppService.get_visible_app_by_id(mock_session, "app-uuid") is None
|
||||
|
||||
def test_get_visible_app_by_id_returns_none_when_visibility_gate_rejects(self):
|
||||
"""``is_openapi_visible`` is the per-row counterpart to
|
||||
``apply_openapi_gate`` — when it returns False the helper must
|
||||
treat the row as invisible (not "found but unauthorized").
|
||||
"""
|
||||
mock_session = MagicMock()
|
||||
app = MagicMock(spec=App)
|
||||
app.status = "normal"
|
||||
mock_session.get.return_value = app
|
||||
|
||||
with patch("services.app_service.is_openapi_visible", return_value=False):
|
||||
assert AppService.get_visible_app_by_id(mock_session, "app-uuid") is None
|
||||
|
||||
def test_find_visible_apps_by_name_returns_scalars_through_visibility_gate(self):
|
||||
"""Tenant-scoped name lookup. The helper passes the SELECT through
|
||||
``apply_openapi_gate`` and materialises ``.scalars()`` into a list
|
||||
so the controller can branch on length (404 / single / 409).
|
||||
"""
|
||||
mock_session = MagicMock()
|
||||
rows = [MagicMock(spec=App), MagicMock(spec=App)]
|
||||
mock_session.execute.return_value.scalars.return_value = iter(rows)
|
||||
|
||||
with patch("services.app_service.apply_openapi_gate", side_effect=lambda q: q) as gate:
|
||||
out = AppService.find_visible_apps_by_name(mock_session, name="my-app", tenant_id="tenant-1")
|
||||
|
||||
assert out == rows
|
||||
# Visibility gate must wrap the SELECT exactly once.
|
||||
gate.assert_called_once()
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
def test_find_visible_apps_by_name_returns_empty_list_on_no_match(self):
|
||||
mock_session = MagicMock()
|
||||
mock_session.execute.return_value.scalars.return_value = iter([])
|
||||
|
||||
with patch("services.app_service.apply_openapi_gate", side_effect=lambda q: q):
|
||||
out = AppService.find_visible_apps_by_name(mock_session, name="nope", tenant_id="tenant-1")
|
||||
|
||||
assert out == []
|
||||
|
||||
def test_find_visible_apps_by_ids_short_circuits_on_empty_input(self):
|
||||
"""Empty id list must not emit ``WHERE id IN ()`` — Postgres
|
||||
rejects empty IN lists and the call is a guaranteed no-op
|
||||
anyway. The helper returns ``[]`` without touching the session.
|
||||
"""
|
||||
mock_session = MagicMock()
|
||||
|
||||
assert AppService.find_visible_apps_by_ids(mock_session, []) == []
|
||||
mock_session.execute.assert_not_called()
|
||||
|
||||
def test_find_visible_apps_by_ids_passes_through_visibility_gate(self):
|
||||
"""Bulk fetch routes through ``apply_openapi_gate`` exactly once
|
||||
and materialises the scalar rows. **No** status filter is
|
||||
applied here — the EE permitted-external pipeline filters
|
||||
non-normal hits in Python so its page count stays anchored.
|
||||
"""
|
||||
mock_session = MagicMock()
|
||||
rows = [MagicMock(spec=App), MagicMock(spec=App)]
|
||||
mock_session.execute.return_value.scalars.return_value.all.return_value = rows
|
||||
|
||||
with patch("services.app_service.apply_openapi_gate", side_effect=lambda q: q) as gate:
|
||||
out = AppService.find_visible_apps_by_ids(mock_session, ["a", "b"])
|
||||
|
||||
assert out == rows
|
||||
gate.assert_called_once()
|
||||
mock_session.execute.assert_called_once()
|
||||
Reference in New Issue
Block a user