mirror of
https://github.com/langgenius/dify.git
synced 2026-06-03 07:26:39 +08:00
Compare commits
5 Commits
deploy/ent
...
fix/server
| Author | SHA1 | Date | |
|---|---|---|---|
| c39861b33e | |||
| f591da7865 | |||
| f19679b217 | |||
| b682591c7a | |||
| 8f6b59feff |
20
.github/workflows/autofix.yml
vendored
20
.github/workflows/autofix.yml
vendored
@ -51,6 +51,15 @@ jobs:
|
||||
with:
|
||||
files: |
|
||||
api/**
|
||||
- name: Check dify-agent inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
id: dify-agent-changes
|
||||
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
|
||||
with:
|
||||
files: |
|
||||
dify-agent/**/*.py
|
||||
dify-agent/pyproject.toml
|
||||
dify-agent/uv.lock
|
||||
- if: github.event_name != 'merge_group'
|
||||
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
|
||||
with:
|
||||
@ -76,6 +85,17 @@ jobs:
|
||||
# Format code
|
||||
uv run ruff format ..
|
||||
|
||||
- if: github.event_name != 'merge_group' && steps.dify-agent-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
cd dify-agent
|
||||
uv sync --dev
|
||||
# fmt first to avoid line too long
|
||||
uv run ruff format .
|
||||
# Fix lint errors
|
||||
uv run ruff check --fix .
|
||||
# Format code
|
||||
uv run ruff format .
|
||||
|
||||
- name: count migration progress
|
||||
if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
|
||||
@ -147,7 +147,7 @@ class AppDescribeApi(AppReadResource):
|
||||
class AppListApi(Resource):
|
||||
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
|
||||
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
|
||||
@auth_router.guard_workspace(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
try:
|
||||
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
@ -1,12 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.openapi.auth.conditions import (
|
||||
EDITION_CE,
|
||||
EDITION_EE,
|
||||
HAS_ALLOWED_ROLES,
|
||||
LOADED_APP_IS_PRIVATE,
|
||||
PATH_HAS_APP_ID,
|
||||
WEBAPP_AUTH_ENABLED,
|
||||
WORKSPACE_MEMBERSHIP_REQUIRED,
|
||||
)
|
||||
from controllers.openapi.auth.data import Edition
|
||||
from controllers.openapi.auth.flow import When
|
||||
@ -16,18 +15,14 @@ from controllers.openapi.auth.prepare import (
|
||||
load_app,
|
||||
load_app_access_mode,
|
||||
load_tenant,
|
||||
load_tenant_from_request,
|
||||
load_workspace_role,
|
||||
resolve_external_user,
|
||||
)
|
||||
from controllers.openapi.auth.verify import (
|
||||
check_acl,
|
||||
check_app_api_enabled,
|
||||
check_app_access,
|
||||
check_membership,
|
||||
check_private_app_permission,
|
||||
check_scope,
|
||||
check_workspace_mismatch,
|
||||
check_workspace_role,
|
||||
)
|
||||
from libs.oauth_bearer import TokenType
|
||||
|
||||
@ -35,17 +30,13 @@ account_pipeline = AuthPipeline(
|
||||
prepare=[
|
||||
When(PATH_HAS_APP_ID, then=load_app),
|
||||
When(PATH_HAS_APP_ID, then=load_tenant),
|
||||
When(WORKSPACE_MEMBERSHIP_REQUIRED & ~PATH_HAS_APP_ID, then=load_tenant_from_request),
|
||||
load_account,
|
||||
When(HAS_ALLOWED_ROLES, then=load_workspace_role),
|
||||
load_account, # all tokens here are account tokens
|
||||
When(PATH_HAS_APP_ID & EDITION_EE, then=load_app_access_mode),
|
||||
],
|
||||
auth=[
|
||||
When(PATH_HAS_APP_ID, then=check_app_api_enabled),
|
||||
check_scope,
|
||||
When((PATH_HAS_APP_ID | WORKSPACE_MEMBERSHIP_REQUIRED) & ~HAS_ALLOWED_ROLES, then=check_membership),
|
||||
When(WORKSPACE_MEMBERSHIP_REQUIRED & PATH_HAS_APP_ID, then=check_workspace_mismatch),
|
||||
When(HAS_ALLOWED_ROLES, then=check_workspace_role),
|
||||
When(EDITION_CE & PATH_HAS_APP_ID, then=check_membership),
|
||||
When(EDITION_EE & PATH_HAS_APP_ID & ~WEBAPP_AUTH_ENABLED, then=check_app_access),
|
||||
When(PATH_HAS_APP_ID & EDITION_EE & WEBAPP_AUTH_ENABLED, then=check_acl),
|
||||
When(EDITION_EE & LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
|
||||
],
|
||||
@ -59,7 +50,6 @@ external_sso_pipeline = AuthPipeline(
|
||||
When(PATH_HAS_APP_ID, then=load_app_access_mode),
|
||||
],
|
||||
auth=[
|
||||
When(PATH_HAS_APP_ID, then=check_app_api_enabled),
|
||||
check_scope,
|
||||
When(PATH_HAS_APP_ID & WEBAPP_AUTH_ENABLED, then=check_acl),
|
||||
When(LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
|
||||
|
||||
@ -50,7 +50,4 @@ EDITION_SAAS = config_cond(lambda: current_edition() == Edition.SAAS)
|
||||
|
||||
WEBAPP_AUTH_ENABLED = config_cond(lambda: FeatureService.get_system_features().webapp_auth.enabled)
|
||||
|
||||
WORKSPACE_MEMBERSHIP_REQUIRED = request_cond(lambda ctx: ctx.workspace_membership)
|
||||
HAS_ALLOWED_ROLES = request_cond(lambda ctx: ctx.allowed_roles is not None)
|
||||
|
||||
LOADED_APP_IS_PRIVATE = data_cond(lambda data: data.app_access_mode == WebAppAccessMode.PRIVATE)
|
||||
|
||||
@ -9,7 +9,7 @@ from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from configs import dify_config
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models.account import Account, Tenant, TenantAccountRole
|
||||
from models.account import Account, Tenant
|
||||
from models.model import App, EndUser
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
|
||||
@ -41,8 +41,6 @@ class RequestContext(BaseModel):
|
||||
token_type: TokenType
|
||||
scope: Scope | None = None
|
||||
path_params: dict[str, str]
|
||||
workspace_membership: bool = False
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None
|
||||
|
||||
|
||||
class AuthData(BaseModel):
|
||||
@ -58,14 +56,10 @@ class AuthData(BaseModel):
|
||||
external_identity: ExternalIdentity | None = None
|
||||
path_params: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None
|
||||
|
||||
app: App | None = None
|
||||
tenant: Tenant | None = None
|
||||
app_access_mode: WebAppAccessMode | None = None
|
||||
|
||||
tenant_role: TenantAccountRole | None = None
|
||||
|
||||
caller: Account | EndUser | None = None
|
||||
caller_kind: Literal["account", "end_user"] | None = None
|
||||
|
||||
|
||||
@ -34,7 +34,6 @@ from libs.oauth_bearer import (
|
||||
reset_auth_ctx,
|
||||
set_auth_ctx,
|
||||
)
|
||||
from models.account import TenantAccountRole
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
|
||||
|
||||
@ -57,15 +56,11 @@ class AuthPipeline:
|
||||
view: Callable,
|
||||
*,
|
||||
scope: Scope | None,
|
||||
workspace_membership: bool = False,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
) -> Any:
|
||||
req_ctx = RequestContext(
|
||||
token_type=identity.token_type,
|
||||
scope=scope,
|
||||
path_params=dict(request.view_args or {}),
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
)
|
||||
|
||||
data = AuthData(
|
||||
@ -76,7 +71,6 @@ class AuthPipeline:
|
||||
scopes=frozenset(identity.scopes),
|
||||
tenants=dict(identity.verified_tenants),
|
||||
required_scope=scope,
|
||||
allowed_roles=allowed_roles,
|
||||
path_params=dict(req_ctx.path_params),
|
||||
external_identity=(
|
||||
ExternalIdentity(email=identity.subject_email, issuer=identity.subject_issuer)
|
||||
@ -127,41 +121,6 @@ class PipelineRouter:
|
||||
scope: Scope | None = None,
|
||||
allowed_token_types: frozenset[TokenType] | None = None,
|
||||
edition: frozenset[Edition] | None = None,
|
||||
workspace_membership: bool = False,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
) -> Callable:
|
||||
return self._make_decorator(
|
||||
scope=scope,
|
||||
allowed_token_types=allowed_token_types,
|
||||
edition=edition,
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
)
|
||||
|
||||
def guard_workspace(
|
||||
self,
|
||||
*,
|
||||
scope: Scope | None = None,
|
||||
allowed_token_types: frozenset[TokenType] | None = None,
|
||||
edition: frozenset[Edition] | None = None,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
) -> Callable:
|
||||
return self._make_decorator(
|
||||
scope=scope,
|
||||
allowed_token_types=allowed_token_types,
|
||||
edition=edition,
|
||||
workspace_membership=True,
|
||||
allowed_roles=allowed_roles,
|
||||
)
|
||||
|
||||
def _make_decorator(
|
||||
self,
|
||||
*,
|
||||
scope: Scope | None,
|
||||
allowed_token_types: frozenset[TokenType] | None,
|
||||
edition: frozenset[Edition] | None,
|
||||
workspace_membership: bool,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None,
|
||||
) -> Callable:
|
||||
def decorator(view: Callable) -> Callable:
|
||||
@wraps(view)
|
||||
@ -173,8 +132,6 @@ class PipelineRouter:
|
||||
scope=scope,
|
||||
allowed_token_types=allowed_token_types,
|
||||
edition=edition,
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
)
|
||||
|
||||
return decorated
|
||||
@ -190,8 +147,6 @@ class PipelineRouter:
|
||||
scope: Scope | None,
|
||||
allowed_token_types: frozenset[TokenType] | None,
|
||||
edition: frozenset[Edition] | None,
|
||||
workspace_membership: bool = False,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
) -> Any:
|
||||
# 404 not 403 — this edition doesn't expose the feature at all
|
||||
if edition is not None and current_edition() not in edition:
|
||||
@ -227,15 +182,7 @@ class PipelineRouter:
|
||||
if not license_checked and Edition.EE in route.required_edition:
|
||||
_check_license()
|
||||
|
||||
return route.pipeline._run(
|
||||
identity,
|
||||
args,
|
||||
kwargs,
|
||||
view,
|
||||
scope=scope,
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
)
|
||||
return route.pipeline._run(identity, args, kwargs, view, scope=scope)
|
||||
|
||||
|
||||
def _should_run(step: Any, req_ctx: RequestContext, data: AuthData | None) -> bool:
|
||||
|
||||
@ -1,8 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from flask import request
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
@ -16,18 +13,16 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppAcce
|
||||
|
||||
|
||||
def load_app(data: AuthData) -> None:
|
||||
if data.app is not None:
|
||||
return
|
||||
app_id = data.path_params["app_id"]
|
||||
app = AppService.get_app_by_id(db.session, app_id)
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
if not app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
data.app = app
|
||||
|
||||
|
||||
def load_tenant(data: AuthData) -> None:
|
||||
if data.tenant is not None:
|
||||
return
|
||||
if data.app is None:
|
||||
raise InternalServerError("pipeline_invariant_violated: app not loaded before load_tenant")
|
||||
tenant = TenantService.get_tenant_by_id(db.session, str(data.app.tenant_id))
|
||||
@ -36,25 +31,7 @@ def load_tenant(data: AuthData) -> None:
|
||||
data.tenant = tenant
|
||||
|
||||
|
||||
def load_tenant_from_request(data: AuthData) -> None:
|
||||
if data.tenant is not None:
|
||||
return
|
||||
workspace_id = data.path_params.get("workspace_id") or request.args.get("workspace_id")
|
||||
if not workspace_id:
|
||||
raise NotFound("workspace not found")
|
||||
try:
|
||||
uuid.UUID(workspace_id)
|
||||
except ValueError:
|
||||
raise NotFound("workspace not found")
|
||||
tenant = TenantService.get_tenant_by_id(db.session, workspace_id)
|
||||
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
|
||||
raise NotFound("workspace not found")
|
||||
data.tenant = tenant
|
||||
|
||||
|
||||
def load_account(data: AuthData) -> None:
|
||||
if data.caller is not None:
|
||||
return
|
||||
account = AccountService.get_account_by_id(db.session, str(data.account_id))
|
||||
if account is None:
|
||||
raise Unauthorized("account not found")
|
||||
@ -64,19 +41,6 @@ def load_account(data: AuthData) -> None:
|
||||
data.caller_kind = "account"
|
||||
|
||||
|
||||
def load_workspace_role(data: AuthData) -> None:
|
||||
if data.tenant_role is not None:
|
||||
return
|
||||
if data.tenant is None or data.account_id is None:
|
||||
return
|
||||
if data.caller is not None and getattr(data.caller, "status", None) != "active":
|
||||
return
|
||||
role = TenantService.get_account_role_in_tenant(db.session, str(data.account_id), str(data.tenant.id))
|
||||
if role is None:
|
||||
return
|
||||
data.tenant_role = role
|
||||
|
||||
|
||||
def resolve_external_user(data: AuthData) -> None:
|
||||
if data.tenant is None or data.app is None or data.external_identity is None:
|
||||
raise Unauthorized("missing context for external user resolution")
|
||||
|
||||
77
api/controllers/openapi/auth/role_gate.py
Normal file
77
api/controllers/openapi/auth/role_gate.py
Normal file
@ -0,0 +1,77 @@
|
||||
"""Workspace role gate.
|
||||
|
||||
Layered on top of `validate_bearer` + `accept_subjects(SubjectType.ACCOUNT)`
|
||||
for routes whose access depends on the caller's `TenantAccountJoin.role`
|
||||
in the workspace named by the `workspace_id` path parameter.
|
||||
|
||||
Usage::
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/members")
|
||||
class Members(Resource):
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
@require_workspace_role() # any member
|
||||
def get(self, workspace_id: str): ...
|
||||
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
def post(self, workspace_id: str): ...
|
||||
|
||||
Non-member callers get 404 (matching `GET /openapi/v1/workspaces/<id>`)
|
||||
so workspace IDs do not leak across tenants. A member without one of the
|
||||
allowed roles gets 403.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TypeVar
|
||||
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import try_get_auth_ctx
|
||||
from models.account import TenantAccountRole
|
||||
from services.account_service import TenantService
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., object])
|
||||
|
||||
|
||||
def require_workspace_role(*allowed_roles: TenantAccountRole) -> Callable[[F], F]:
|
||||
"""Gate a route on the caller's role in ``workspace_id``.
|
||||
|
||||
Pass no roles to require only membership. Pass one or more roles to
|
||||
require the caller's role be in that set.
|
||||
"""
|
||||
|
||||
allowed = frozenset(allowed_roles)
|
||||
|
||||
def deco(fn: F) -> F:
|
||||
@wraps(fn)
|
||||
def wrapper(*args: object, **kwargs: object) -> object:
|
||||
ctx = try_get_auth_ctx()
|
||||
if ctx is None or ctx.account_id is None:
|
||||
raise RuntimeError(
|
||||
"require_workspace_role called without account-bearer context; "
|
||||
"stack validate_bearer + accept_subjects(SubjectType.ACCOUNT) above it"
|
||||
)
|
||||
|
||||
workspace_id = kwargs.get("workspace_id")
|
||||
if not workspace_id:
|
||||
raise RuntimeError("require_workspace_role expects a 'workspace_id' route parameter")
|
||||
|
||||
role = TenantService.get_account_role_in_tenant(db.session, str(ctx.account_id), str(workspace_id))
|
||||
|
||||
if role is None:
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
if allowed and role not in allowed:
|
||||
raise Forbidden("insufficient workspace role")
|
||||
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
|
||||
return deco
|
||||
@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import request
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized, UnprocessableEntity
|
||||
from werkzeug.exceptions import Forbidden, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
@ -31,30 +30,6 @@ def check_membership(data: AuthData) -> None:
|
||||
)
|
||||
|
||||
|
||||
def check_workspace_mismatch(data: AuthData) -> None:
|
||||
if data.tenant is None:
|
||||
return
|
||||
request_workspace_id = data.path_params.get("workspace_id") or request.args.get("workspace_id")
|
||||
if request_workspace_id and request_workspace_id != str(data.tenant.id):
|
||||
raise UnprocessableEntity("workspace_id does not match app's workspace")
|
||||
|
||||
|
||||
def check_workspace_role(data: AuthData) -> None:
|
||||
if data.allowed_roles is None:
|
||||
return
|
||||
if data.tenant_role is None:
|
||||
raise NotFound("workspace not found")
|
||||
if data.tenant_role not in data.allowed_roles:
|
||||
raise Forbidden("insufficient workspace role")
|
||||
|
||||
|
||||
def check_app_api_enabled(data: AuthData) -> None:
|
||||
if data.app is None:
|
||||
return
|
||||
if not data.app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
|
||||
|
||||
def check_app_access(data: AuthData) -> None:
|
||||
if data.tenant is None:
|
||||
return
|
||||
|
||||
@ -5,8 +5,9 @@ endpoints. Account bearers (dfoa_) see every tenant they're a member of.
|
||||
External SSO bearers (dfoe_) have no account_id and so see an empty list —
|
||||
that matches /openapi/v1/account.
|
||||
|
||||
Member-management endpoints use ``guard_workspace`` which enforces
|
||||
workspace membership and optional role requirements via the auth pipeline.
|
||||
Member-management endpoints are gated by both `accept_subjects` (SSO out)
|
||||
and `require_workspace_role` (membership / role lookup against the path's
|
||||
``workspace_id``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -36,6 +37,7 @@ from controllers.openapi._models import (
|
||||
)
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.auth.role_gate import require_workspace_role
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models import Account, Tenant, TenantAccountJoin
|
||||
@ -150,7 +152,8 @@ class WorkspaceSwitchApi(Resource):
|
||||
"""
|
||||
|
||||
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
|
||||
@auth_router.guard_workspace(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role()
|
||||
def post(self, workspace_id: str, *, auth_data: AuthData):
|
||||
account = _load_account(auth_data.account_id)
|
||||
|
||||
@ -176,7 +179,8 @@ class WorkspaceMembersApi(Resource):
|
||||
|
||||
@openapi_ns.doc(params=query_params_from_model(MemberListQuery))
|
||||
@openapi_ns.response(200, "Member list", openapi_ns.models[MemberListResponse.__name__])
|
||||
@auth_router.guard_workspace(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role()
|
||||
def get(self, workspace_id: str, *, auth_data: AuthData):
|
||||
try:
|
||||
query = MemberListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -198,11 +202,8 @@ class WorkspaceMembersApi(Resource):
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[MemberInvitePayload.__name__])
|
||||
@openapi_ns.response(201, "Member invited", openapi_ns.models[MemberInviteResponse.__name__])
|
||||
@auth_router.guard_workspace(
|
||||
scope=Scope.WORKSPACE_WRITE,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
|
||||
)
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
def post(self, workspace_id: str, *, auth_data: AuthData):
|
||||
payload = _validate_body(MemberInvitePayload)
|
||||
inviter = _load_account(auth_data.account_id)
|
||||
@ -252,11 +253,8 @@ class WorkspaceMemberApi(Resource):
|
||||
"""
|
||||
|
||||
@openapi_ns.response(200, "Member removed", openapi_ns.models[MemberActionResponse.__name__])
|
||||
@auth_router.guard_workspace(
|
||||
scope=Scope.WORKSPACE_WRITE,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
|
||||
)
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
def delete(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
|
||||
operator = _load_account(auth_data.account_id)
|
||||
tenant = _load_tenant(workspace_id)
|
||||
@ -286,11 +284,8 @@ class WorkspaceMemberRoleApi(Resource):
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[MemberRoleUpdatePayload.__name__])
|
||||
@openapi_ns.response(200, "Role updated", openapi_ns.models[MemberActionResponse.__name__])
|
||||
@auth_router.guard_workspace(
|
||||
scope=Scope.WORKSPACE_WRITE,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
|
||||
)
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
def put(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
|
||||
payload = _validate_body(MemberRoleUpdatePayload)
|
||||
operator = _load_account(auth_data.account_id)
|
||||
|
||||
@ -1329,9 +1329,9 @@ class TenantService:
|
||||
) -> TenantAccountRole | None:
|
||||
"""Return the caller's role in ``tenant_id``, or ``None`` if not a member.
|
||||
|
||||
Backs the openapi auth pipeline's ``load_workspace_role`` prepare step:
|
||||
``None`` is treated as non-member (the pipeline maps it to 404 — no
|
||||
cross-tenant ID leak) and an out-of-set role to 403.
|
||||
Backs ``controllers.openapi.auth.role_gate.require_workspace_role``:
|
||||
the gate maps ``None`` to 404 (non-member — no cross-tenant ID leak)
|
||||
and an out-of-set role to 403, so it never touches the ORM itself.
|
||||
|
||||
``None``/empty ``account_id`` short-circuits to ``None`` so SSO
|
||||
bearers (no account) collapse to the non-member path. Mirrors the
|
||||
|
||||
@ -16,20 +16,20 @@ def test_auth_router_is_pipeline_router():
|
||||
assert isinstance(auth_router, PipelineRouter)
|
||||
|
||||
|
||||
def test_account_pipeline_prepare_has_six_entries():
|
||||
assert len(account_pipeline._prepare) == 6
|
||||
def test_account_pipeline_prepare_has_four_entries():
|
||||
assert len(account_pipeline._prepare) == 4
|
||||
|
||||
|
||||
def test_account_auth_list_has_seven_entries():
|
||||
assert len(account_pipeline._auth) == 7
|
||||
def test_account_auth_list_has_five_entries():
|
||||
assert len(account_pipeline._auth) == 5
|
||||
|
||||
|
||||
def test_external_sso_pipeline_prepare_has_four_entries():
|
||||
assert len(external_sso_pipeline._prepare) == 4
|
||||
|
||||
|
||||
def test_external_sso_auth_list_has_four_entries():
|
||||
assert len(external_sso_pipeline._auth) == 4
|
||||
def test_external_sso_auth_list_has_three_entries():
|
||||
assert len(external_sso_pipeline._auth) == 3
|
||||
|
||||
|
||||
def test_account_pipeline_has_unconditional_load_account():
|
||||
@ -41,14 +41,17 @@ def test_external_sso_pipeline_all_prepare_entries_are_when():
|
||||
assert all(isinstance(s, When) for s in external_sso_pipeline._prepare)
|
||||
|
||||
|
||||
def test_account_pipeline_has_one_unconditional_auth_step():
|
||||
non_when = [s for s in account_pipeline._auth if not isinstance(s, When)]
|
||||
assert len(non_when) == 1
|
||||
def test_first_auth_entry_is_check_scope_in_both_pipelines():
|
||||
assert not isinstance(account_pipeline._auth[0], When)
|
||||
assert not isinstance(external_sso_pipeline._auth[0], When)
|
||||
|
||||
|
||||
def test_external_sso_pipeline_has_one_unconditional_auth_step():
|
||||
non_when = [s for s in external_sso_pipeline._auth if not isinstance(s, When)]
|
||||
assert len(non_when) == 1
|
||||
def test_remaining_auth_entries_are_when_for_account():
|
||||
assert all(isinstance(s, When) for s in account_pipeline._auth[1:])
|
||||
|
||||
|
||||
def test_remaining_auth_entries_are_when_for_external_sso():
|
||||
assert all(isinstance(s, When) for s in external_sso_pipeline._auth[1:])
|
||||
|
||||
|
||||
def test_router_routes_contain_both_token_types():
|
||||
|
||||
@ -4,13 +4,11 @@ from controllers.openapi.auth.conditions import (
|
||||
EDITION_CE,
|
||||
EDITION_EE,
|
||||
EDITION_SAAS,
|
||||
HAS_ALLOWED_ROLES,
|
||||
LOADED_APP_IS_PRIVATE,
|
||||
PATH_HAS_APP_ID,
|
||||
TOKEN_IS_OAUTH_ACCOUNT,
|
||||
TOKEN_IS_OAUTH_EXTERNAL_SSO,
|
||||
WEBAPP_AUTH_ENABLED,
|
||||
WORKSPACE_MEMBERSHIP_REQUIRED,
|
||||
Cond,
|
||||
config_cond,
|
||||
data_cond,
|
||||
@ -18,15 +16,13 @@ from controllers.openapi.auth.conditions import (
|
||||
)
|
||||
from controllers.openapi.auth.data import AuthData, Edition, RequestContext
|
||||
from libs.oauth_bearer import TokenType
|
||||
from models.account import TenantAccountRole
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
|
||||
|
||||
def _ctx(token_type=TokenType.OAUTH_ACCOUNT, path_params=None, **kwargs):
|
||||
def _ctx(token_type=TokenType.OAUTH_ACCOUNT, path_params=None):
|
||||
return RequestContext(
|
||||
token_type=token_type,
|
||||
path_params=path_params or {},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@ -145,28 +141,3 @@ def test_loaded_app_is_private():
|
||||
assert LOADED_APP_IS_PRIVATE(_ctx(), data_public) is False
|
||||
assert LOADED_APP_IS_PRIVATE(_ctx(), data_none) is False
|
||||
assert LOADED_APP_IS_PRIVATE(_ctx(), None) is False
|
||||
|
||||
|
||||
def test_workspace_membership_required_true():
|
||||
assert WORKSPACE_MEMBERSHIP_REQUIRED(_ctx(workspace_membership=True)) is True
|
||||
|
||||
|
||||
def test_workspace_membership_required_false():
|
||||
assert WORKSPACE_MEMBERSHIP_REQUIRED(_ctx(workspace_membership=False)) is False
|
||||
|
||||
|
||||
def test_workspace_membership_required_default():
|
||||
assert WORKSPACE_MEMBERSHIP_REQUIRED(_ctx()) is False
|
||||
|
||||
|
||||
def test_has_allowed_roles_true():
|
||||
ctx = _ctx(allowed_roles=frozenset({TenantAccountRole.OWNER}))
|
||||
assert HAS_ALLOWED_ROLES(ctx) is True
|
||||
|
||||
|
||||
def test_has_allowed_roles_false():
|
||||
assert HAS_ALLOWED_ROLES(_ctx(allowed_roles=None)) is False
|
||||
|
||||
|
||||
def test_has_allowed_roles_default():
|
||||
assert HAS_ALLOWED_ROLES(_ctx()) is False
|
||||
|
||||
@ -115,69 +115,3 @@ def test_auth_data_token_id_optional():
|
||||
scopes=frozenset(),
|
||||
)
|
||||
assert data.token_id is None
|
||||
|
||||
|
||||
def test_request_context_workspace_membership_default_false():
|
||||
ctx = RequestContext(token_type=TokenType.OAUTH_ACCOUNT, path_params={})
|
||||
assert ctx.workspace_membership is False
|
||||
|
||||
|
||||
def test_request_context_workspace_membership_set():
|
||||
ctx = RequestContext(token_type=TokenType.OAUTH_ACCOUNT, path_params={}, workspace_membership=True)
|
||||
assert ctx.workspace_membership is True
|
||||
|
||||
|
||||
def test_request_context_allowed_roles_default_none():
|
||||
ctx = RequestContext(token_type=TokenType.OAUTH_ACCOUNT, path_params={})
|
||||
assert ctx.allowed_roles is None
|
||||
|
||||
|
||||
def test_request_context_allowed_roles_set():
|
||||
from models.account import TenantAccountRole
|
||||
|
||||
roles = frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN})
|
||||
ctx = RequestContext(token_type=TokenType.OAUTH_ACCOUNT, path_params={}, allowed_roles=roles)
|
||||
assert ctx.allowed_roles == roles
|
||||
|
||||
|
||||
def test_auth_data_allowed_roles_default_none():
|
||||
data = AuthData(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
token_hash="abc",
|
||||
scopes=frozenset(),
|
||||
)
|
||||
assert data.allowed_roles is None
|
||||
|
||||
|
||||
def test_auth_data_allowed_roles_set():
|
||||
from models.account import TenantAccountRole
|
||||
|
||||
roles = frozenset({TenantAccountRole.ADMIN})
|
||||
data = AuthData(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
token_hash="abc",
|
||||
scopes=frozenset(),
|
||||
allowed_roles=roles,
|
||||
)
|
||||
assert data.allowed_roles == roles
|
||||
|
||||
|
||||
def test_auth_data_tenant_role_default_none():
|
||||
data = AuthData(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
token_hash="abc",
|
||||
scopes=frozenset(),
|
||||
)
|
||||
assert data.tenant_role is None
|
||||
|
||||
|
||||
def test_auth_data_tenant_role_set():
|
||||
from models.account import TenantAccountRole
|
||||
|
||||
data = AuthData(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
token_hash="abc",
|
||||
scopes=frozenset(),
|
||||
tenant_role=TenantAccountRole.ADMIN,
|
||||
)
|
||||
assert data.tenant_role == TenantAccountRole.ADMIN
|
||||
|
||||
@ -247,60 +247,6 @@ def test_guard_populates_external_identity_from_subject_email(app):
|
||||
assert received["data"].external_identity.issuer == "https://idp.example.com"
|
||||
|
||||
|
||||
def test_guard_workspace_sets_membership_and_roles(app):
|
||||
from models.account import TenantAccountRole
|
||||
|
||||
router = _make_router()
|
||||
received = {}
|
||||
|
||||
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
|
||||
with (
|
||||
patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
|
||||
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
|
||||
patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()),
|
||||
patch("controllers.openapi.auth.pipeline.reset_auth_ctx"),
|
||||
):
|
||||
mock_auth.return_value.authenticate.return_value = _fake_identity()
|
||||
|
||||
roles = frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN})
|
||||
|
||||
@router.guard_workspace(
|
||||
scope=Scope.FULL,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=roles,
|
||||
)
|
||||
def view(*, auth_data):
|
||||
received["data"] = auth_data
|
||||
|
||||
view()
|
||||
|
||||
assert isinstance(received["data"], AuthData)
|
||||
assert received["data"].allowed_roles == roles
|
||||
|
||||
|
||||
def test_guard_workspace_without_roles(app):
|
||||
router = _make_router()
|
||||
received = {}
|
||||
|
||||
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
|
||||
with (
|
||||
patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
|
||||
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
|
||||
patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()),
|
||||
patch("controllers.openapi.auth.pipeline.reset_auth_ctx"),
|
||||
):
|
||||
mock_auth.return_value.authenticate.return_value = _fake_identity()
|
||||
|
||||
@router.guard_workspace(scope=Scope.FULL)
|
||||
def view(*, auth_data):
|
||||
received["data"] = auth_data
|
||||
|
||||
view()
|
||||
|
||||
assert isinstance(received["data"], AuthData)
|
||||
assert received["data"].allowed_roles is None
|
||||
|
||||
|
||||
def test_guard_no_external_identity_when_subject_email_absent(app):
|
||||
router = _make_router()
|
||||
received = {}
|
||||
|
||||
@ -2,7 +2,6 @@ import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.data import AuthData, ExternalIdentity
|
||||
@ -11,12 +10,9 @@ from controllers.openapi.auth.prepare import (
|
||||
load_app,
|
||||
load_app_access_mode,
|
||||
load_tenant,
|
||||
load_tenant_from_request,
|
||||
load_workspace_role,
|
||||
resolve_external_user,
|
||||
)
|
||||
from libs.oauth_bearer import TokenType
|
||||
from models.account import TenantAccountRole
|
||||
|
||||
|
||||
def _make_auth_data(**kwargs) -> AuthData:
|
||||
@ -58,21 +54,14 @@ def test_load_app_raises_not_found_when_not_normal():
|
||||
load_app(data)
|
||||
|
||||
|
||||
def test_load_app_stashes_app_even_when_api_disabled():
|
||||
def test_load_app_raises_forbidden_when_api_disabled():
|
||||
app = MagicMock()
|
||||
app.status = "normal"
|
||||
app.enable_api = False
|
||||
data = _make_auth_data(path_params={"app_id": "abc"})
|
||||
with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=app):
|
||||
load_app(data)
|
||||
assert data.app is app
|
||||
|
||||
|
||||
def test_load_app_skips_when_already_set():
|
||||
existing_app = MagicMock()
|
||||
data = _make_auth_data(app=existing_app, path_params={"app_id": "abc"})
|
||||
load_app(data)
|
||||
assert data.app is existing_app
|
||||
with pytest.raises(Forbidden):
|
||||
load_app(data)
|
||||
|
||||
|
||||
def test_load_tenant_writes_tenant():
|
||||
@ -86,13 +75,6 @@ def test_load_tenant_writes_tenant():
|
||||
assert data.tenant is tenant
|
||||
|
||||
|
||||
def test_load_tenant_skips_when_already_set():
|
||||
existing_tenant = MagicMock()
|
||||
data = _make_auth_data(app=MagicMock(), tenant=existing_tenant)
|
||||
load_tenant(data)
|
||||
assert data.tenant is existing_tenant
|
||||
|
||||
|
||||
def test_load_tenant_raises_forbidden_when_archived():
|
||||
from models.account import TenantStatus
|
||||
|
||||
@ -133,13 +115,6 @@ def test_load_account_writes_caller():
|
||||
assert data.caller_kind == "account"
|
||||
|
||||
|
||||
def test_load_account_skips_when_already_set():
|
||||
existing_caller = MagicMock()
|
||||
data = _make_auth_data(account_id=uuid.uuid4(), caller=existing_caller)
|
||||
load_account(data)
|
||||
assert data.caller is existing_caller
|
||||
|
||||
|
||||
def test_load_account_sets_current_tenant_when_tenant_present():
|
||||
account = MagicMock()
|
||||
tenant = MagicMock()
|
||||
@ -206,143 +181,3 @@ def test_load_app_access_mode_no_op_when_app_missing():
|
||||
data = _make_auth_data()
|
||||
load_app_access_mode(data)
|
||||
assert data.app_access_mode is None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_app():
|
||||
return Flask(__name__)
|
||||
|
||||
|
||||
def test_load_tenant_from_request_from_path_params(flask_app):
|
||||
tenant = MagicMock()
|
||||
tenant.status = "normal"
|
||||
wid = str(uuid.uuid4())
|
||||
data = _make_auth_data(path_params={"workspace_id": wid})
|
||||
with flask_app.test_request_context("/test"):
|
||||
with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=tenant):
|
||||
load_tenant_from_request(data)
|
||||
assert data.tenant is tenant
|
||||
|
||||
|
||||
def test_load_tenant_from_request_from_query_param(flask_app):
|
||||
tenant = MagicMock()
|
||||
tenant.status = "normal"
|
||||
wid = str(uuid.uuid4())
|
||||
data = _make_auth_data(path_params={})
|
||||
with flask_app.test_request_context(f"/test?workspace_id={wid}"):
|
||||
with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=tenant):
|
||||
load_tenant_from_request(data)
|
||||
assert data.tenant is tenant
|
||||
|
||||
|
||||
def test_load_tenant_from_request_skips_when_already_set(flask_app):
|
||||
existing_tenant = MagicMock()
|
||||
data = _make_auth_data(tenant=existing_tenant, path_params={})
|
||||
with flask_app.test_request_context("/test"):
|
||||
load_tenant_from_request(data)
|
||||
assert data.tenant is existing_tenant
|
||||
|
||||
|
||||
def test_load_tenant_from_request_raises_not_found_when_no_id(flask_app):
|
||||
data = _make_auth_data(path_params={})
|
||||
with flask_app.test_request_context("/test"):
|
||||
with pytest.raises(NotFound):
|
||||
load_tenant_from_request(data)
|
||||
|
||||
|
||||
def test_load_tenant_from_request_raises_not_found_when_missing(flask_app):
|
||||
wid = str(uuid.uuid4())
|
||||
data = _make_auth_data(path_params={"workspace_id": wid})
|
||||
with flask_app.test_request_context("/test"):
|
||||
with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=None):
|
||||
with pytest.raises(NotFound):
|
||||
load_tenant_from_request(data)
|
||||
|
||||
|
||||
def test_load_tenant_from_request_raises_not_found_when_archived(flask_app):
|
||||
from models.account import TenantStatus
|
||||
|
||||
tenant = MagicMock()
|
||||
tenant.status = TenantStatus.ARCHIVE
|
||||
wid = str(uuid.uuid4())
|
||||
data = _make_auth_data(path_params={"workspace_id": wid})
|
||||
with flask_app.test_request_context("/test"):
|
||||
with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=tenant):
|
||||
with pytest.raises(NotFound):
|
||||
load_tenant_from_request(data)
|
||||
|
||||
|
||||
def test_load_tenant_from_request_raises_not_found_when_invalid_uuid(flask_app):
|
||||
data = _make_auth_data(path_params={"workspace_id": "not-a-uuid"})
|
||||
with flask_app.test_request_context("/test"):
|
||||
with pytest.raises(NotFound):
|
||||
load_tenant_from_request(data)
|
||||
|
||||
|
||||
# --- load_workspace_role ---
|
||||
|
||||
|
||||
def test_load_workspace_role_stashes_role():
|
||||
tenant = MagicMock()
|
||||
tenant.id = uuid.uuid4()
|
||||
caller = MagicMock()
|
||||
caller.status = "active"
|
||||
data = _make_auth_data(account_id=uuid.uuid4(), tenant=tenant, caller=caller)
|
||||
with patch(
|
||||
"controllers.openapi.auth.prepare.TenantService.get_account_role_in_tenant",
|
||||
return_value=TenantAccountRole.ADMIN,
|
||||
):
|
||||
load_workspace_role(data)
|
||||
assert data.tenant_role == TenantAccountRole.ADMIN
|
||||
|
||||
|
||||
def test_load_workspace_role_none_when_not_member():
|
||||
tenant = MagicMock()
|
||||
tenant.id = uuid.uuid4()
|
||||
caller = MagicMock()
|
||||
caller.status = "active"
|
||||
data = _make_auth_data(account_id=uuid.uuid4(), tenant=tenant, caller=caller)
|
||||
with patch(
|
||||
"controllers.openapi.auth.prepare.TenantService.get_account_role_in_tenant",
|
||||
return_value=None,
|
||||
):
|
||||
load_workspace_role(data)
|
||||
assert data.tenant_role is None
|
||||
|
||||
|
||||
def test_load_workspace_role_none_when_account_inactive():
|
||||
tenant = MagicMock()
|
||||
tenant.id = uuid.uuid4()
|
||||
caller = MagicMock()
|
||||
caller.status = "banned"
|
||||
data = _make_auth_data(account_id=uuid.uuid4(), tenant=tenant, caller=caller)
|
||||
load_workspace_role(data)
|
||||
assert data.tenant_role is None
|
||||
|
||||
|
||||
def test_load_workspace_role_skips_when_already_set():
|
||||
tenant = MagicMock()
|
||||
tenant.id = uuid.uuid4()
|
||||
caller = MagicMock()
|
||||
caller.status = "active"
|
||||
data = _make_auth_data(
|
||||
account_id=uuid.uuid4(),
|
||||
tenant=tenant,
|
||||
caller=caller,
|
||||
tenant_role=TenantAccountRole.OWNER,
|
||||
)
|
||||
load_workspace_role(data)
|
||||
assert data.tenant_role == TenantAccountRole.OWNER
|
||||
|
||||
|
||||
def test_load_workspace_role_skips_when_tenant_missing():
|
||||
data = _make_auth_data(account_id=uuid.uuid4())
|
||||
load_workspace_role(data)
|
||||
assert data.tenant_role is None
|
||||
|
||||
|
||||
def test_load_workspace_role_skips_when_account_id_missing():
|
||||
tenant = MagicMock()
|
||||
data = _make_auth_data(tenant=tenant, account_id=None)
|
||||
load_workspace_role(data)
|
||||
assert data.tenant_role is None
|
||||
|
||||
330
api/tests/unit_tests/controllers/openapi/auth/test_role_gate.py
Normal file
330
api/tests/unit_tests/controllers/openapi/auth/test_role_gate.py
Normal file
@ -0,0 +1,330 @@
|
||||
"""Role-gate tests.
|
||||
|
||||
The decorator wraps `validate_bearer` + `accept_subjects` and must:
|
||||
- 404 when caller is not a member of ``workspace_id`` (parity with
|
||||
`GET /openapi/v1/workspaces/<id>`; prevents tenant-id existence leak)
|
||||
- 403 when caller IS a member but their role is not in the allowed set
|
||||
- pass through when role matches (or when no role restriction given)
|
||||
- raise RuntimeError on missing auth context / account_id / workspace_id —
|
||||
those are wiring bugs, not user-driven failures
|
||||
|
||||
Identity is read from the openapi auth ContextVar — the slot
|
||||
`validate_bearer` publishes — so these tests seed it via `_seed`
|
||||
(``set_auth_ctx``), NOT ``flask.g``. `test_seeding_only_flask_g_*`
|
||||
locks in that ``g`` is *not* a valid identity source.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.openapi.auth.role_gate import require_workspace_role
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, reset_auth_ctx, set_auth_ctx
|
||||
from models.account import TenantAccountRole
|
||||
|
||||
# Tokens from `_seed`'s `set_auth_ctx` calls, drained after each test so a
|
||||
# published identity can't leak into the next (the ContextVar is module-global
|
||||
# and worker threads are reused). Seed via `_seed(...)`, never `flask.g`.
|
||||
_seed_tokens: list = []
|
||||
|
||||
|
||||
def _seed(ctx: AuthContext) -> None:
|
||||
_seed_tokens.append(set_auth_ctx(ctx))
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_auth_ctx():
|
||||
yield
|
||||
while _seed_tokens:
|
||||
reset_auth_ctx(_seed_tokens.pop())
|
||||
|
||||
|
||||
def _account_ctx(account_id: uuid.UUID | None = None) -> AuthContext:
|
||||
return AuthContext(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
subject_email="user@example.com",
|
||||
subject_issuer="dify:account",
|
||||
account_id=account_id or uuid.uuid4(),
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
token_id=uuid.uuid4(),
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
expires_at=datetime.now(UTC),
|
||||
token_hash="h1",
|
||||
verified_tenants={},
|
||||
)
|
||||
|
||||
|
||||
def _sso_ctx() -> AuthContext:
|
||||
return AuthContext(
|
||||
subject_type=SubjectType.EXTERNAL_SSO,
|
||||
subject_email="sso@partner.com",
|
||||
subject_issuer="https://idp.partner.com",
|
||||
account_id=None,
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({Scope.APPS_RUN}),
|
||||
token_id=uuid.uuid4(),
|
||||
token_type=TokenType.OAUTH_EXTERNAL_SSO,
|
||||
expires_at=datetime.now(UTC),
|
||||
token_hash="h2",
|
||||
verified_tenants={},
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _stub_role(role: TenantAccountRole | None):
|
||||
"""Stub the service-layer membership lookup the gate delegates to.
|
||||
|
||||
The gate no longer issues SQL itself — it calls
|
||||
``TenantService.get_account_role_in_tenant`` and acts purely on the
|
||||
returned role (``None`` → non-member). These tests pin that behaviour;
|
||||
the query itself is covered in ``TestTenantService``.
|
||||
"""
|
||||
with patch(
|
||||
"controllers.openapi.auth.role_gate.TenantService.get_account_role_in_tenant",
|
||||
return_value=role,
|
||||
) as mocked:
|
||||
yield mocked
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-member → 404
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_non_member_gets_404():
|
||||
app = Flask(__name__)
|
||||
workspace_id = str(uuid.uuid4())
|
||||
|
||||
@require_workspace_role()
|
||||
def view(workspace_id: str) -> str:
|
||||
return "ok"
|
||||
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/switch"):
|
||||
_seed(_account_ctx())
|
||||
with _stub_role(None):
|
||||
with pytest.raises(NotFound):
|
||||
view(workspace_id=workspace_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Member with insufficient role → 403
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_normal_member_blocked_when_admin_required():
|
||||
app = Flask(__name__)
|
||||
workspace_id = str(uuid.uuid4())
|
||||
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
def view(workspace_id: str) -> str:
|
||||
return "ok"
|
||||
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/members"):
|
||||
_seed(_account_ctx())
|
||||
with _stub_role(TenantAccountRole.NORMAL):
|
||||
with pytest.raises(Forbidden):
|
||||
view(workspace_id=workspace_id)
|
||||
|
||||
|
||||
def test_editor_blocked_when_admin_required():
|
||||
app = Flask(__name__)
|
||||
workspace_id = str(uuid.uuid4())
|
||||
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
def view(workspace_id: str) -> str:
|
||||
return "ok"
|
||||
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/members"):
|
||||
_seed(_account_ctx())
|
||||
with _stub_role(TenantAccountRole.EDITOR):
|
||||
with pytest.raises(Forbidden):
|
||||
view(workspace_id=workspace_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Member with allowed role → pass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_admin_passes_when_admin_required():
|
||||
app = Flask(__name__)
|
||||
workspace_id = str(uuid.uuid4())
|
||||
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
def view(workspace_id: str) -> str:
|
||||
return "ok"
|
||||
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/members"):
|
||||
_seed(_account_ctx())
|
||||
with _stub_role(TenantAccountRole.ADMIN):
|
||||
assert view(workspace_id=workspace_id) == "ok"
|
||||
|
||||
|
||||
def test_owner_passes_when_admin_required():
|
||||
app = Flask(__name__)
|
||||
workspace_id = str(uuid.uuid4())
|
||||
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
def view(workspace_id: str) -> str:
|
||||
return "ok"
|
||||
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/members"):
|
||||
_seed(_account_ctx())
|
||||
with _stub_role(TenantAccountRole.OWNER):
|
||||
assert view(workspace_id=workspace_id) == "ok"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Membership-only (no role restriction)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_membership_only_passes_for_any_role():
|
||||
app = Flask(__name__)
|
||||
workspace_id = str(uuid.uuid4())
|
||||
|
||||
@require_workspace_role()
|
||||
def view(workspace_id: str) -> str:
|
||||
return "ok"
|
||||
|
||||
for role in (
|
||||
TenantAccountRole.OWNER,
|
||||
TenantAccountRole.ADMIN,
|
||||
TenantAccountRole.EDITOR,
|
||||
TenantAccountRole.NORMAL,
|
||||
TenantAccountRole.DATASET_OPERATOR,
|
||||
):
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/switch"):
|
||||
_seed(_account_ctx())
|
||||
with _stub_role(role):
|
||||
assert view(workspace_id=workspace_id) == "ok"
|
||||
|
||||
|
||||
def test_membership_only_still_404s_non_member():
|
||||
app = Flask(__name__)
|
||||
workspace_id = str(uuid.uuid4())
|
||||
|
||||
@require_workspace_role()
|
||||
def view(workspace_id: str) -> str:
|
||||
return "ok"
|
||||
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/switch"):
|
||||
_seed(_account_ctx())
|
||||
with _stub_role(None):
|
||||
with pytest.raises(NotFound):
|
||||
view(workspace_id=workspace_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lookup is scoped to the caller's account_id and the URL workspace_id
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_lookup_is_scoped_to_caller_and_workspace():
|
||||
"""The decorator must delegate the lookup keyed on
|
||||
`(caller's account_id, URL workspace_id)` — otherwise a member of
|
||||
workspace A could quietly hit endpoints for workspace B. Assert the
|
||||
exact arguments handed to the service; the SQL those arguments compile
|
||||
to is pinned in ``TestTenantService.test_get_account_role_in_tenant_*``.
|
||||
"""
|
||||
|
||||
app = Flask(__name__)
|
||||
account_id = uuid.uuid4()
|
||||
workspace_id = str(uuid.uuid4())
|
||||
|
||||
@require_workspace_role()
|
||||
def view(workspace_id: str) -> str:
|
||||
return "ok"
|
||||
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/switch"):
|
||||
_seed(_account_ctx(account_id=account_id))
|
||||
with _stub_role(TenantAccountRole.NORMAL) as mocked:
|
||||
view(workspace_id=workspace_id)
|
||||
|
||||
_session, passed_account_id, passed_workspace_id = mocked.call_args.args
|
||||
assert passed_account_id == str(account_id)
|
||||
assert passed_workspace_id == workspace_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Wiring bugs surface as RuntimeError (loud), not 403 (silent)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_missing_auth_ctx_is_runtime_error():
|
||||
app = Flask(__name__)
|
||||
workspace_id = str(uuid.uuid4())
|
||||
|
||||
@require_workspace_role()
|
||||
def view(workspace_id: str) -> str:
|
||||
return "ok"
|
||||
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/switch"):
|
||||
with pytest.raises(RuntimeError):
|
||||
view(workspace_id=workspace_id)
|
||||
|
||||
|
||||
def test_seeding_only_flask_g_does_not_satisfy_gate():
|
||||
"""Regression — pins the identity source to the ContextVar, not ``flask.g``.
|
||||
|
||||
Production fills the ContextVar (``validate_bearer`` → ``set_auth_ctx``)
|
||||
and never touches ``g.auth_ctx``. An earlier revision of this gate read
|
||||
``g.auth_ctx``, so every real request raised RuntimeError → 500 while the
|
||||
suite stayed green (it seeded ``g`` directly). Here we seed ONLY ``g`` and
|
||||
leave the ContextVar empty: the gate must still raise, proving it does not
|
||||
accept ``g`` as an identity source. Reading ``g`` again would let the
|
||||
membership lookup run (stubbed to succeed) and this would fail.
|
||||
"""
|
||||
from flask import g
|
||||
|
||||
app = Flask(__name__)
|
||||
workspace_id = str(uuid.uuid4())
|
||||
|
||||
@require_workspace_role()
|
||||
def view(workspace_id: str) -> str:
|
||||
return "ok"
|
||||
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/switch"):
|
||||
g.auth_ctx = _account_ctx() # the wrong slot — must be ignored
|
||||
with _stub_role(TenantAccountRole.OWNER):
|
||||
with pytest.raises(RuntimeError):
|
||||
view(workspace_id=workspace_id)
|
||||
|
||||
|
||||
def test_sso_caller_is_runtime_error():
|
||||
"""External SSO context has account_id=None — the caller stacked the
|
||||
role gate without `accept_subjects(SubjectType.ACCOUNT)`. That's a
|
||||
wiring bug, surface it as RuntimeError rather than 404 the SSO user."""
|
||||
|
||||
app = Flask(__name__)
|
||||
workspace_id = str(uuid.uuid4())
|
||||
|
||||
@require_workspace_role()
|
||||
def view(workspace_id: str) -> str:
|
||||
return "ok"
|
||||
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/switch"):
|
||||
_seed(_sso_ctx())
|
||||
with pytest.raises(RuntimeError):
|
||||
view(workspace_id=workspace_id)
|
||||
|
||||
|
||||
def test_missing_workspace_id_kwarg_is_runtime_error():
|
||||
app = Flask(__name__)
|
||||
|
||||
@require_workspace_role()
|
||||
def view() -> str:
|
||||
return "ok"
|
||||
|
||||
with app.test_request_context("/openapi/v1/foo"):
|
||||
_seed(_account_ctx())
|
||||
with pytest.raises(RuntimeError):
|
||||
view()
|
||||
@ -2,22 +2,18 @@ import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
from werkzeug.exceptions import Forbidden, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.auth.verify import (
|
||||
check_acl,
|
||||
check_app_access,
|
||||
check_app_api_enabled,
|
||||
check_membership,
|
||||
check_private_app_permission,
|
||||
check_scope,
|
||||
check_workspace_mismatch,
|
||||
check_workspace_role,
|
||||
)
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models.account import Tenant, TenantAccountRole
|
||||
from models.account import Tenant
|
||||
from models.model import App
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
|
||||
@ -144,92 +140,3 @@ def test_check_private_app_permission_passes_when_allowed():
|
||||
target = "controllers.openapi.auth.verify.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp"
|
||||
with patch(target, return_value=True):
|
||||
check_private_app_permission(data)
|
||||
|
||||
|
||||
# --- check_workspace_mismatch ---
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_app():
|
||||
return Flask(__name__)
|
||||
|
||||
|
||||
def test_check_workspace_mismatch_passes_when_tenant_none(flask_app):
|
||||
with flask_app.test_request_context("/test"):
|
||||
check_workspace_mismatch(_data(tenant=None))
|
||||
|
||||
|
||||
def test_check_workspace_mismatch_passes_when_ids_match(flask_app):
|
||||
tenant = MagicMock(spec=Tenant)
|
||||
tid = uuid.uuid4()
|
||||
tenant.id = tid
|
||||
with flask_app.test_request_context(f"/test?workspace_id={tid}"):
|
||||
check_workspace_mismatch(_data(tenant=tenant, path_params={}))
|
||||
|
||||
|
||||
def test_check_workspace_mismatch_raises_422_on_mismatch(flask_app):
|
||||
from werkzeug.exceptions import UnprocessableEntity
|
||||
|
||||
tenant = MagicMock(spec=Tenant)
|
||||
tenant.id = uuid.uuid4()
|
||||
other_id = uuid.uuid4()
|
||||
with flask_app.test_request_context(f"/test?workspace_id={other_id}"):
|
||||
with pytest.raises(UnprocessableEntity):
|
||||
check_workspace_mismatch(_data(tenant=tenant, path_params={}))
|
||||
|
||||
|
||||
def test_check_workspace_mismatch_passes_when_no_request_workspace_id(flask_app):
|
||||
tenant = MagicMock(spec=Tenant)
|
||||
tenant.id = uuid.uuid4()
|
||||
with flask_app.test_request_context("/test"):
|
||||
check_workspace_mismatch(_data(tenant=tenant, path_params={}))
|
||||
|
||||
|
||||
# --- check_workspace_role ---
|
||||
|
||||
|
||||
def test_check_workspace_role_passes_when_allowed_roles_none():
|
||||
check_workspace_role(_data(allowed_roles=None))
|
||||
|
||||
|
||||
def test_check_workspace_role_raises_not_found_when_not_member():
|
||||
data = _data(tenant_role=None, allowed_roles=frozenset({TenantAccountRole.ADMIN}))
|
||||
with pytest.raises(NotFound):
|
||||
check_workspace_role(data)
|
||||
|
||||
|
||||
def test_check_workspace_role_raises_forbidden_when_wrong_role():
|
||||
data = _data(
|
||||
tenant_role=TenantAccountRole.EDITOR,
|
||||
allowed_roles=frozenset({TenantAccountRole.OWNER}),
|
||||
)
|
||||
with pytest.raises(Forbidden, match="insufficient workspace role"):
|
||||
check_workspace_role(data)
|
||||
|
||||
|
||||
def test_check_workspace_role_passes_when_role_allowed():
|
||||
data = _data(
|
||||
tenant_role=TenantAccountRole.ADMIN,
|
||||
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
|
||||
)
|
||||
check_workspace_role(data)
|
||||
|
||||
|
||||
# --- check_app_api_enabled ---
|
||||
|
||||
|
||||
def test_check_app_api_enabled_passes_when_enabled():
|
||||
app = MagicMock(spec=App)
|
||||
app.enable_api = True
|
||||
check_app_api_enabled(_data(app=app))
|
||||
|
||||
|
||||
def test_check_app_api_enabled_raises_forbidden_when_disabled():
|
||||
app = MagicMock(spec=App)
|
||||
app.enable_api = False
|
||||
with pytest.raises(Forbidden, match="service_api_disabled"):
|
||||
check_app_api_enabled(_data(app=app))
|
||||
|
||||
|
||||
def test_check_app_api_enabled_passes_when_app_none():
|
||||
check_app_api_enabled(_data(app=None))
|
||||
|
||||
@ -9,18 +9,7 @@ from controllers.openapi.auth.pipeline import PipelineRouter
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
|
||||
|
||||
def _stub_execute(
|
||||
self,
|
||||
args,
|
||||
kwargs,
|
||||
view,
|
||||
*,
|
||||
scope=None,
|
||||
allowed_token_types=None,
|
||||
edition=None,
|
||||
workspace_membership=False,
|
||||
allowed_roles=None,
|
||||
):
|
||||
def _stub_execute(self, args, kwargs, view, *, scope=None, allowed_token_types=None, edition=None):
|
||||
"""Bypass all auth logic; inject minimal AuthData and call the view directly."""
|
||||
kwargs["auth_data"] = AuthData(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
@ -29,7 +18,6 @@ def _stub_execute(
|
||||
token_id=uuid.uuid4(),
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
required_scope=scope,
|
||||
allowed_roles=allowed_roles,
|
||||
)
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
||||
@ -9,10 +9,9 @@ Coverage:
|
||||
|
||||
Auth-pipeline plumbing is bypassed via the `bypass_pipeline` fixture from
|
||||
conftest.py; the bearer identity is seeded into the openapi auth ContextVar
|
||||
via `_seed` (the slot `validate_bearer` publishes). Tests that exercise
|
||||
endpoint *bodies* skip the single `guard_workspace` decorator via
|
||||
``__wrapped__`` — membership and role enforcement live in the auth pipeline
|
||||
and are covered in `auth/test_prepare.py` and `auth/test_verify.py`.
|
||||
via `_seed` (the slot `validate_bearer` publishes), and the role gate's DB
|
||||
lookup is mocked. Tests that exercise endpoint *bodies* skip the decorators
|
||||
via ``__wrapped__`` since those layers are covered in `auth/test_role_gate.py`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -269,7 +268,7 @@ def test_switch_returns_workspace_detail_with_current_true(app, bypass_pipeline,
|
||||
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/switch", method="POST"):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
body, status = api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
body, status = api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
assert status == 200
|
||||
assert body["id"] == ws_id
|
||||
@ -297,7 +296,7 @@ def test_switch_404s_when_service_raises_account_not_link_tenant(app, bypass_pip
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/switch", method="POST"):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(NotFound):
|
||||
api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -331,7 +330,7 @@ def test_members_list_returns_normalized_rows(app, bypass_pipeline, monkeypatch)
|
||||
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members"):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
body, status = api.get.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
body, status = api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
assert status == 200
|
||||
assert body["page"] == 1
|
||||
@ -373,7 +372,7 @@ def test_members_list_paginates_with_query_params(app, bypass_pipeline, monkeypa
|
||||
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members?page=2&limit=2"):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
body, status = api.get.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
body, status = api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
assert status == 200
|
||||
assert body["page"] == 2
|
||||
@ -396,7 +395,7 @@ def test_members_list_rejects_unknown_query_param(app, bypass_pipeline, monkeypa
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members?pg=2"):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(BadRequest):
|
||||
api.get.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -434,7 +433,7 @@ def test_invite_happy_path_returns_invite_url_and_member_id(app, bypass_pipeline
|
||||
content_type="application/json",
|
||||
):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
body, status = api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
body, status = api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
assert status == 201
|
||||
assert body["result"] == "success"
|
||||
@ -519,7 +518,7 @@ def test_invite_blocked_by_saas_members_cap(app, bypass_pipeline, monkeypatch):
|
||||
with _invite_request(app, ws_id, acct_id):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(Forbidden) as exc_info:
|
||||
api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
body = exc_info.value.response.json
|
||||
assert body["code"] == "members.limit_exceeded"
|
||||
@ -565,7 +564,7 @@ def test_invite_blocked_by_ee_workspace_members_license(app, bypass_pipeline, mo
|
||||
with _invite_request(app, ws_id, acct_id):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(Forbidden) as exc_info:
|
||||
api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
body = exc_info.value.response.json
|
||||
assert body["code"] == "workspace_members.license_exceeded"
|
||||
@ -604,7 +603,7 @@ def test_invite_ce_passes_when_both_caps_disabled(app, bypass_pipeline, monkeypa
|
||||
|
||||
with _invite_request(app, ws_id, acct_id):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
body, status = api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
body, status = api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
assert status == 201
|
||||
assert body["email"] == "new@example.com"
|
||||
@ -633,7 +632,7 @@ def test_invite_400_when_already_in_tenant(app, bypass_pipeline, monkeypatch):
|
||||
):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(BadRequest):
|
||||
api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -666,7 +665,7 @@ def test_delete_member_happy_path(app, bypass_pipeline, monkeypatch):
|
||||
method="DELETE",
|
||||
):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
body, status = api.delete.__wrapped__(
|
||||
body, status = api.delete.__wrapped__.__wrapped__(
|
||||
api, workspace_id=ws_id, member_id=member_id, auth_data=_auth_data(acct_id)
|
||||
)
|
||||
|
||||
@ -708,7 +707,7 @@ def test_delete_member_exception_mapping(app, bypass_pipeline, monkeypatch, exc,
|
||||
):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(expected):
|
||||
api.delete.__wrapped__(
|
||||
api.delete.__wrapped__.__wrapped__(
|
||||
api,
|
||||
workspace_id=ws_id,
|
||||
member_id=member_id,
|
||||
@ -735,7 +734,7 @@ def test_delete_member_404_when_member_missing(app, bypass_pipeline, monkeypatch
|
||||
):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(NotFound):
|
||||
api.delete.__wrapped__(
|
||||
api.delete.__wrapped__.__wrapped__(
|
||||
api,
|
||||
workspace_id=ws_id,
|
||||
member_id=member_id,
|
||||
@ -775,7 +774,9 @@ def test_update_role_happy_path(app, bypass_pipeline, monkeypatch):
|
||||
content_type="application/json",
|
||||
):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
body, status = api.put.__wrapped__(api, workspace_id=ws_id, member_id=member_id, auth_data=_auth_data(acct_id))
|
||||
body, status = api.put.__wrapped__.__wrapped__(
|
||||
api, workspace_id=ws_id, member_id=member_id, auth_data=_auth_data(acct_id)
|
||||
)
|
||||
|
||||
assert status == 200
|
||||
assert body == {"result": "success"}
|
||||
@ -819,7 +820,7 @@ def test_update_role_exception_mapping(app, bypass_pipeline, monkeypatch, exc, e
|
||||
):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(expected):
|
||||
api.put.__wrapped__(
|
||||
api.put.__wrapped__.__wrapped__(
|
||||
api,
|
||||
workspace_id=ws_id,
|
||||
member_id=member_id,
|
||||
@ -827,6 +828,44 @@ def test_update_role_exception_mapping(app, bypass_pipeline, monkeypatch, exc, e
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Role gate composition — non-member sees 404 even with valid bearer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_non_member_caller_gets_404_on_switch(app, bypass_pipeline, monkeypatch):
|
||||
"""End-to-end: caller has valid account bearer but no membership in
|
||||
the requested workspace. The role gate must short-circuit to 404
|
||||
before any TenantService method is touched."""
|
||||
ws_id = str(uuid.uuid4())
|
||||
acct_id = uuid.uuid4()
|
||||
api = WorkspaceSwitchApi()
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
|
||||
switch_mock = Mock()
|
||||
monkeypatch.setattr(
|
||||
sys.modules["controllers.openapi.workspaces"],
|
||||
"TenantService",
|
||||
_tenant_service(switch_tenant=switch_mock),
|
||||
)
|
||||
monkeypatch.setattr(sys.modules["controllers.openapi.workspaces"], "db", mock_db)
|
||||
monkeypatch.setattr(sys.modules["controllers.openapi.auth.role_gate"], "db", mock_db)
|
||||
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/switch", method="POST"):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
# Strip only the bearer + surface-gate wrappers; keep the role gate.
|
||||
# Decorator stack (innermost → outermost):
|
||||
# role_gate → accept_subjects → validate_bearer
|
||||
# `post.__wrapped__` is now the role-gate wrapper directly (auth_router.guard is the only outer wrapper).
|
||||
gated = api.post.__wrapped__
|
||||
with pytest.raises(NotFound):
|
||||
gated(api, workspace_id=ws_id)
|
||||
|
||||
switch_mock.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _load_tenant rejects archived tenant
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -852,7 +891,7 @@ def test_load_tenant_rejects_archived_workspace(app, bypass_pipeline, monkeypatc
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members"):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(NotFound):
|
||||
api.get.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -886,4 +925,4 @@ def test_invite_400_when_register_error(app, bypass_pipeline, monkeypatch):
|
||||
):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(BadRequest):
|
||||
api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
@ -638,8 +638,8 @@ class TestTenantService:
|
||||
callable_func(*args, **kwargs)
|
||||
|
||||
# ==================== get_account_role_in_tenant Tests ====================
|
||||
# Backs the auth pipeline's `load_workspace_role`: None => non-member
|
||||
# (pipeline maps to 404), otherwise the caller's role (out-of-set role => 403).
|
||||
# Backs `require_workspace_role`: None => non-member (gate maps to 404),
|
||||
# otherwise the caller's role (gate maps an out-of-set role to 403).
|
||||
|
||||
def test_get_account_role_in_tenant_returns_role_for_member(self):
|
||||
"""A row in TenantAccountJoin yields the caller's role."""
|
||||
|
||||
1
api/uv.lock
generated
1
api/uv.lock
generated
@ -1300,7 +1300,6 @@ requires-dist = [
|
||||
{ name = "pydantic-ai-slim", extras = ["anthropic", "google", "openai"], marker = "extra == 'server'", specifier = ">=1.85.1,<2.0.0" },
|
||||
{ name = "pydantic-settings", marker = "extra == 'server'", specifier = ">=2.12.0,<3.0.0" },
|
||||
{ name = "redis", marker = "extra == 'server'", specifier = ">=7.4.0,<8.0.0" },
|
||||
{ name = "shell-session-manager", marker = "extra == 'server'", specifier = "==2.1.1" },
|
||||
{ name = "typing-extensions", specifier = ">=4.12.2,<5.0.0" },
|
||||
{ name = "uvicorn", extras = ["standard"], marker = "extra == 'server'", specifier = "==0.46.0" },
|
||||
]
|
||||
|
||||
6
cli/.gitignore
vendored
6
cli/.gitignore
vendored
@ -4,4 +4,8 @@ node_modules/
|
||||
*.tsbuildinfo
|
||||
.vitest-cache/
|
||||
docs/specs/
|
||||
context/
|
||||
context/
|
||||
test/**/*.ts.map
|
||||
test/**/*.js.map
|
||||
test/**/*.js
|
||||
test/**/*.d.ts
|
||||
|
||||
@ -2,7 +2,7 @@ import type { StubServer } from '@test/fixtures/stub-server'
|
||||
import { testHttpClient } from '@test/fixtures/http-client'
|
||||
import { jsonResponder, startStubServer } from '@test/fixtures/stub-server'
|
||||
import { afterEach, describe, expect, it } from 'vitest'
|
||||
import { isBaseError } from '@/errors/base'
|
||||
import { isHttpClientError } from '@/errors/base'
|
||||
import { AccountSessionsClient } from './account-sessions.js'
|
||||
|
||||
const LIST_BODY = { page: 1, limit: 100, total: 0, has_more: false, data: [] }
|
||||
@ -70,7 +70,7 @@ describe('AccountSessionsClient.revoke', () => {
|
||||
jsonResponder(404, { error: { code: 'not_found', message: 'session not found' } }, cap))
|
||||
|
||||
await expect(makeClient(stub.url).revoke('missing')).rejects.toSatisfy(
|
||||
err => isBaseError(err) && err.httpStatus === 404,
|
||||
err => isHttpClientError(err) && err.httpStatus === 404,
|
||||
)
|
||||
})
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ import type { StubServer } from '@test/fixtures/stub-server'
|
||||
import { testHttpClient } from '@test/fixtures/http-client'
|
||||
import { jsonResponder, startStubServer } from '@test/fixtures/stub-server'
|
||||
import { afterEach, describe, expect, it } from 'vitest'
|
||||
import { isBaseError } from '@/errors/base'
|
||||
import { isHttpClientError } from '@/errors/base'
|
||||
import { AccountClient } from './account.js'
|
||||
|
||||
function makeClient(host: string): AccountClient {
|
||||
@ -35,7 +35,7 @@ describe('AccountClient.get', () => {
|
||||
stub = await startStubServer(cap => jsonResponder(401, { error: 'expired' }, cap))
|
||||
|
||||
await expect(makeClient(stub.url).get()).rejects.toSatisfy(
|
||||
err => isBaseError(err) && err.httpStatus === 401,
|
||||
err => isHttpClientError(err) && err.httpStatus === 401,
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@ -2,7 +2,7 @@ import type { StubServer } from '@test/fixtures/stub-server'
|
||||
import { testHttpClient } from '@test/fixtures/http-client'
|
||||
import { jsonResponder, startStubServer } from '@test/fixtures/stub-server'
|
||||
import { afterEach, describe, expect, it } from 'vitest'
|
||||
import { isBaseError } from '@/errors/base'
|
||||
import { isHttpClientError } from '@/errors/base'
|
||||
import { AppsClient } from './apps.js'
|
||||
|
||||
const LIST_BODY = { page: 1, limit: 20, total: 0, has_more: false, data: [] }
|
||||
@ -74,7 +74,7 @@ describe('AppsClient.list', () => {
|
||||
stub = await startStubServer(cap => jsonResponder(403, { error: 'forbidden' }, cap))
|
||||
|
||||
await expect(makeClient(stub.url).list({ workspaceId: 'ws-1' })).rejects.toSatisfy(
|
||||
err => isBaseError(err) && err.httpStatus === 403,
|
||||
err => isHttpClientError(err) && err.httpStatus === 403,
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@ -5,7 +5,7 @@ import { join } from 'node:path'
|
||||
import { testHttpClient } from '@test/fixtures/http-client'
|
||||
import { jsonResponder, startStubServer } from '@test/fixtures/stub-server'
|
||||
import { afterEach, beforeEach, describe, expect, it } from 'vitest'
|
||||
import { isBaseError } from '@/errors/base'
|
||||
import { isHttpClientError } from '@/errors/base'
|
||||
import { FileUploadClient } from './file-upload.js'
|
||||
|
||||
const UPLOADED = {
|
||||
@ -70,7 +70,7 @@ describe('FileUploadClient.upload', () => {
|
||||
stub = await startStubServer(cap => jsonResponder(413, { error: 'file too large' }, cap))
|
||||
|
||||
await expect(makeClient(stub.url).upload('app-1', filePath)).rejects.toSatisfy(
|
||||
err => isBaseError(err) && err.httpStatus === 413,
|
||||
err => isHttpClientError(err) && err.httpStatus === 413,
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@ -2,7 +2,7 @@ import type { StubServer } from '@test/fixtures/stub-server'
|
||||
import { testHttpClient } from '@test/fixtures/http-client'
|
||||
import { jsonResponder, startStubServer } from '@test/fixtures/stub-server'
|
||||
import { afterEach, describe, expect, it } from 'vitest'
|
||||
import { isBaseError } from '@/errors/base'
|
||||
import { isHttpClientError } from '@/errors/base'
|
||||
import { MembersClient } from './members.js'
|
||||
import { WorkspacesClient } from './workspaces.js'
|
||||
|
||||
@ -62,7 +62,7 @@ describe('MembersClient.list', () => {
|
||||
stub = await startStubServer(cap => jsonResponder(403, { error: 'forbidden' }, cap))
|
||||
|
||||
await expect(makeClient(stub.url).list('ws-1')).rejects.toSatisfy(
|
||||
err => isBaseError(err) && err.httpStatus === 403,
|
||||
err => isHttpClientError(err) && err.httpStatus === 403,
|
||||
)
|
||||
})
|
||||
|
||||
@ -70,7 +70,7 @@ describe('MembersClient.list', () => {
|
||||
stub = await startStubServer(cap => jsonResponder(404, { error: 'not found' }, cap))
|
||||
|
||||
await expect(makeClient(stub.url).list('ws-missing')).rejects.toSatisfy(
|
||||
err => isBaseError(err) && err.httpStatus === 404,
|
||||
err => isHttpClientError(err) && err.httpStatus === 404,
|
||||
)
|
||||
})
|
||||
})
|
||||
@ -117,7 +117,7 @@ describe('MembersClient.invite', () => {
|
||||
|
||||
await expect(
|
||||
makeClient(stub.url).invite('ws-1', { email: 'u@e.com', role: 'normal' }),
|
||||
).rejects.toSatisfy(err => isBaseError(err) && err.httpStatus === 400)
|
||||
).rejects.toSatisfy(err => isHttpClientError(err) && err.httpStatus === 400)
|
||||
})
|
||||
})
|
||||
|
||||
@ -142,7 +142,7 @@ describe('MembersClient.remove', () => {
|
||||
stub = await startStubServer(cap => jsonResponder(400, { error: 'cannot operate self' }, cap))
|
||||
|
||||
await expect(makeClient(stub.url).remove('ws-1', 'm-1')).rejects.toSatisfy(
|
||||
err => isBaseError(err) && err.httpStatus === 400,
|
||||
err => isHttpClientError(err) && err.httpStatus === 400,
|
||||
)
|
||||
})
|
||||
})
|
||||
@ -170,7 +170,7 @@ describe('MembersClient.updateRole', () => {
|
||||
|
||||
await expect(
|
||||
makeClient(stub.url).updateRole('ws-1', 'm-1', { role: 'admin' }),
|
||||
).rejects.toSatisfy(err => isBaseError(err) && err.httpStatus === 400)
|
||||
).rejects.toSatisfy(err => isHttpClientError(err) && err.httpStatus === 400)
|
||||
})
|
||||
})
|
||||
|
||||
@ -209,7 +209,7 @@ describe('WorkspacesClient.switch (integration with stub)', () => {
|
||||
|
||||
const client = new WorkspacesClient(testHttpClient(stub.url, 'dfoa_test'))
|
||||
await expect(client.switch('ws-x')).rejects.toSatisfy(
|
||||
err => isBaseError(err) && err.httpStatus === 404,
|
||||
err => isHttpClientError(err) && err.httpStatus === 404,
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import type { HttpClient } from '@/http/types'
|
||||
import { BaseError } from '@/errors/base'
|
||||
import { BaseError, HttpClientError } from '@/errors/base'
|
||||
import { ErrorCode } from '@/errors/codes'
|
||||
|
||||
export const DEFAULT_CLIENT_ID = 'difyctl'
|
||||
@ -80,7 +80,7 @@ export class DeviceFlowApi {
|
||||
if (res.status === 404)
|
||||
throw versionSkew()
|
||||
if (!res.ok) {
|
||||
throw new BaseError({
|
||||
throw new HttpClientError({
|
||||
code: ErrorCode.Server4xxOther,
|
||||
message: `device/code: HTTP ${res.status}`,
|
||||
httpStatus: res.status,
|
||||
@ -133,8 +133,8 @@ export class DeviceFlowApi {
|
||||
}
|
||||
}
|
||||
|
||||
function versionSkew(): BaseError {
|
||||
return new BaseError({
|
||||
function versionSkew(): HttpClientError {
|
||||
return new HttpClientError({
|
||||
code: ErrorCode.UnsupportedEndpoint,
|
||||
message: 'this Dify host does not implement the OAuth device flow',
|
||||
httpStatus: 404,
|
||||
|
||||
@ -2,7 +2,7 @@ import type { StubServer } from '@test/fixtures/stub-server'
|
||||
import { testHttpClient } from '@test/fixtures/http-client'
|
||||
import { jsonResponder, startStubServer } from '@test/fixtures/stub-server'
|
||||
import { afterEach, describe, expect, it } from 'vitest'
|
||||
import { isBaseError } from '@/errors/base'
|
||||
import { isHttpClientError } from '@/errors/base'
|
||||
import { WorkspacesClient } from './workspaces.js'
|
||||
|
||||
// WorkspacesClient.switch is covered in members.test.ts; this file covers list().
|
||||
@ -30,14 +30,14 @@ describe('WorkspacesClient.list', () => {
|
||||
|
||||
expect(stub.captured.method).toBe('GET')
|
||||
expect(stub.captured.url).toBe('/openapi/v1/workspaces')
|
||||
expect(res.workspaces[0].id).toBe('ws-1')
|
||||
expect(res.workspaces[0]?.id).toBe('ws-1')
|
||||
})
|
||||
|
||||
it('maps 401 to a classified BaseError', async () => {
|
||||
stub = await startStubServer(cap => jsonResponder(401, { error: 'expired' }, cap))
|
||||
|
||||
await expect(makeClient(stub.url).list()).rejects.toSatisfy(
|
||||
err => isBaseError(err) && err.httpStatus === 401,
|
||||
err => isHttpClientError(err) && err.httpStatus === 401,
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@ -357,14 +357,14 @@ describe('runApp', () => {
|
||||
// warm cache with successful run
|
||||
await runApp(
|
||||
{ appId: 'app-1', message: 'hi' },
|
||||
{ bundle: bundle(), http: testHttpClient(mock.url, 'dfoa_test'), host: mock.url, io, cache },
|
||||
{ active: active(), http: testHttpClient(mock.url, 'dfoa_test'), host: mock.url, io, cache },
|
||||
)
|
||||
expect(cache.get(mock.url, 'app-1')).toBeDefined()
|
||||
|
||||
mock.setScenario('run-422-stale')
|
||||
const err = await runApp(
|
||||
{ appId: 'app-1', message: 'hi' },
|
||||
{ bundle: bundle(), http: testHttpClient(mock.url, { bearer: 'dfoa_test', retryAttempts: 0 }), host: mock.url, io, cache },
|
||||
{ active: active(), http: testHttpClient(mock.url, { bearer: 'dfoa_test', retryAttempts: 0 }), host: mock.url, io, cache },
|
||||
).catch((e: unknown) => e)
|
||||
expect(err).toMatchObject({ code: 'server_4xx_other', httpStatus: 422 })
|
||||
expect((err as { hint?: string }).hint).toMatch(/cache cleared/)
|
||||
|
||||
@ -7,7 +7,7 @@ import { AppRunClient } from '@/api/app-run'
|
||||
import { AppsClient } from '@/api/apps'
|
||||
import { FileUploadClient } from '@/api/file-upload'
|
||||
import { pickStrategy } from '@/commands/run/app/_strategies/index'
|
||||
import { BaseError } from '@/errors/base'
|
||||
import { BaseError, HttpClientError } from '@/errors/base'
|
||||
import { ErrorCode } from '@/errors/codes'
|
||||
import { getEnv, processExit } from '@/sys/index'
|
||||
import { FieldInfo } from '@/types/app-meta'
|
||||
@ -87,7 +87,7 @@ export async function runApp(opts: RunAppOptions, deps: RunAppDeps): Promise<voi
|
||||
await executeRun(opts, deps, meta, wsId)
|
||||
}
|
||||
catch (err) {
|
||||
if (err instanceof BaseError && err.httpStatus === 422) {
|
||||
if (err instanceof HttpClientError && err.httpStatus === 422) {
|
||||
await meta.invalidate(opts.appId)
|
||||
throw err.withHint('app metadata cache cleared — if the app was recently republished, run the command again')
|
||||
}
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import type { HttpClientError } from '@/errors/base'
|
||||
import type { SseEvent } from '@/http/sse'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import { collect, collectorFor, decodeStreamError, HitlPauseError } from './sse-collector'
|
||||
@ -130,7 +131,7 @@ describe('decodeStreamError', () => {
|
||||
const err = decodeStreamError(enc.encode(JSON.stringify(env)))
|
||||
expect(err.message).toBe(inner.args.description)
|
||||
expect(err.code).toBe('server_4xx_other')
|
||||
expect(err.httpStatus).toBe(400)
|
||||
expect((err as HttpClientError).httpStatus).toBe(400)
|
||||
})
|
||||
|
||||
it('unwraps openapi-v1 invoke-error: falls back to inner.message when no args.description', () => {
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import type { BaseError } from '@/errors/base'
|
||||
import type { SseEvent } from '@/http/sse'
|
||||
import { newError } from '@/errors/base'
|
||||
import { HttpClientError, newError } from '@/errors/base'
|
||||
import { ErrorCode } from '@/errors/codes'
|
||||
import { RUN_MODES } from './handlers'
|
||||
|
||||
@ -173,9 +173,9 @@ export function decodeStreamError(data: Uint8Array): BaseError {
|
||||
const code = env.status !== undefined && env.status > 0 && env.status < 500
|
||||
? ErrorCode.Server4xxOther
|
||||
: ErrorCode.Server5xx
|
||||
let err = newError(code, message)
|
||||
const err = newError(code, message)
|
||||
if (env.status !== undefined && env.status > 0)
|
||||
err = err.withHttpStatus(env.status)
|
||||
return HttpClientError.from(err).withHttpStatus(env.status)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import { BaseError, isBaseError, newError, unknownError } from './base'
|
||||
import { BaseError, HttpClientError, isBaseError, newError, unknownError } from './base'
|
||||
import { ErrorCode, ExitCode } from './codes'
|
||||
|
||||
describe('BaseError', () => {
|
||||
it('captures code, message, optional fields', () => {
|
||||
const err = new BaseError({
|
||||
const err = new HttpClientError({
|
||||
code: ErrorCode.AuthExpired,
|
||||
message: 'session expired',
|
||||
hint: 'run difyctl auth login',
|
||||
@ -30,7 +30,6 @@ describe('BaseError', () => {
|
||||
expect(newError(ErrorCode.AuthExpired, 'x').exit()).toBe(ExitCode.Auth)
|
||||
expect(newError(ErrorCode.UsageInvalidFlag, 'x').exit()).toBe(ExitCode.Usage)
|
||||
expect(newError(ErrorCode.VersionSkew, 'x').exit()).toBe(ExitCode.VersionCompat)
|
||||
expect(newError(ErrorCode.NetworkDns, 'x').exit()).toBe(ExitCode.Generic)
|
||||
})
|
||||
|
||||
it('toString without hint formats "<code>: <message>"', () => {
|
||||
@ -56,7 +55,7 @@ describe('BaseError', () => {
|
||||
|
||||
it('withHttpStatus + withRequest + wrap chain immutably', () => {
|
||||
const cause = new Error('underlying')
|
||||
const built = newError(ErrorCode.NetworkTimeout, 'timed out')
|
||||
const built = HttpClientError.from(newError(ErrorCode.NetworkConnection, 'timed out'))
|
||||
.withHttpStatus(504)
|
||||
.withRequest('POST', 'https://x/y')
|
||||
.wrap(cause)
|
||||
@ -68,7 +67,7 @@ describe('BaseError', () => {
|
||||
|
||||
it('wrap exposes cause via standard Error.cause property', () => {
|
||||
const cause = new Error('underlying failure')
|
||||
const wrapped = newError(ErrorCode.NetworkTimeout, 'timed out').wrap(cause)
|
||||
const wrapped = newError(ErrorCode.NetworkConnection, 'timed out').wrap(cause)
|
||||
expect(wrapped.cause).toBe(cause)
|
||||
})
|
||||
|
||||
@ -86,3 +85,56 @@ describe('BaseError', () => {
|
||||
expect(err.cause).toBe(cause)
|
||||
})
|
||||
})
|
||||
|
||||
describe('error envelope', () => {
|
||||
it('emits required fields only when minimal', () => {
|
||||
const err = newError(ErrorCode.Unknown, 'boom')
|
||||
expect(err.toEnvelope()).toEqual({
|
||||
error: { code: 'unknown', message: 'boom' },
|
||||
})
|
||||
})
|
||||
|
||||
it('includes hint / http_status / method / url when present', () => {
|
||||
const err = HttpClientError.from(newError(ErrorCode.NetworkConnection, 'timed out'))
|
||||
.withHint('check your network')
|
||||
.withHttpStatus(504)
|
||||
.withRequest('POST', 'https://api.dify.ai/v1/x')
|
||||
expect(err.toEnvelope()).toEqual({
|
||||
error: {
|
||||
code: 'network_connection',
|
||||
message: 'timed out',
|
||||
hint: 'check your network',
|
||||
http_status: 504,
|
||||
method: 'POST',
|
||||
url: 'https://api.dify.ai/v1/x',
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
it('renderEnvelope returns a single-line JSON string', () => {
|
||||
const err = newError(ErrorCode.AuthExpired, 'session expired')
|
||||
.withHint('run difyctl auth login')
|
||||
const out = JSON.stringify(err.toEnvelope())
|
||||
expect(out).toBe(
|
||||
'{"error":{"code":"auth_expired","message":"session expired","hint":"run difyctl auth login"}}',
|
||||
)
|
||||
expect(out).not.toContain('\n')
|
||||
})
|
||||
|
||||
it('renderEnvelope output round-trips through JSON.parse to an ErrorEnvelope shape', () => {
|
||||
const err = newError(ErrorCode.UsageInvalidFlag, 'bad flag').withHint('see --help')
|
||||
const parsed = JSON.parse(JSON.stringify(err.toEnvelope()))
|
||||
expect(parsed).toEqual({
|
||||
error: { code: 'usage_invalid_flag', message: 'bad flag', hint: 'see --help' },
|
||||
})
|
||||
})
|
||||
|
||||
it('omits undefined optional fields entirely (no `hint: null`)', () => {
|
||||
const err = newError(ErrorCode.Server5xx, 'upstream broke')
|
||||
const envelope = err.toEnvelope()
|
||||
expect(envelope.error).not.toHaveProperty('hint')
|
||||
expect(envelope.error).not.toHaveProperty('http_status')
|
||||
expect(envelope.error).not.toHaveProperty('method')
|
||||
expect(envelope.error).not.toHaveProperty('url')
|
||||
})
|
||||
})
|
||||
|
||||
@ -1,31 +1,24 @@
|
||||
import type { ErrorCodeValue, ExitCodeValue } from './codes'
|
||||
import type { ErrorEnvelope, PrintableError } from './format'
|
||||
import { ErrorCode, exitFor } from './codes'
|
||||
|
||||
export type BaseErrorOptions = {
|
||||
readonly code: ErrorCodeValue
|
||||
readonly message: string
|
||||
readonly hint?: string
|
||||
readonly httpStatus?: number
|
||||
readonly method?: string
|
||||
readonly url?: string
|
||||
readonly cause?: unknown
|
||||
}
|
||||
|
||||
export class BaseError extends Error {
|
||||
export class BaseError extends Error implements PrintableError {
|
||||
readonly code: ErrorCodeValue
|
||||
readonly hint?: string
|
||||
readonly httpStatus?: number
|
||||
readonly method?: string
|
||||
readonly url?: string
|
||||
|
||||
constructor(opts: BaseErrorOptions) {
|
||||
super(opts.message, opts.cause === undefined ? undefined : { cause: opts.cause })
|
||||
this.name = 'BaseError'
|
||||
this.code = opts.code
|
||||
this.hint = opts.hint
|
||||
this.httpStatus = opts.httpStatus
|
||||
this.method = opts.method
|
||||
this.url = opts.url
|
||||
|
||||
Object.setPrototypeOf(this, new.target.prototype)
|
||||
}
|
||||
|
||||
@ -39,30 +32,31 @@ export class BaseError extends Error {
|
||||
: `${this.code}: ${this.message}`
|
||||
}
|
||||
|
||||
withHint(hint: string): BaseError {
|
||||
return new BaseError({ ...this.snapshot(), hint })
|
||||
toEnvelope(): ErrorEnvelope {
|
||||
const payload: ErrorEnvelope['error'] = {
|
||||
code: this.code,
|
||||
message: this.message,
|
||||
}
|
||||
if (this.hint !== undefined)
|
||||
payload.hint = this.hint
|
||||
return { error: payload }
|
||||
}
|
||||
|
||||
withHttpStatus(httpStatus: number): BaseError {
|
||||
return new BaseError({ ...this.snapshot(), httpStatus })
|
||||
withHint<T extends BaseError>(this: T, hint: string): T {
|
||||
const Ctor = this.constructor as new (opts: BaseErrorOptions) => T
|
||||
return new Ctor({ ...this.snapshot(), hint })
|
||||
}
|
||||
|
||||
withRequest(method: string, url: string): BaseError {
|
||||
return new BaseError({ ...this.snapshot(), method, url })
|
||||
wrap<T extends BaseError>(this: T, cause: unknown): T {
|
||||
const Ctor = this.constructor as new (opts: BaseErrorOptions) => T
|
||||
return new Ctor({ ...this.snapshot(), cause })
|
||||
}
|
||||
|
||||
wrap(cause: unknown): BaseError {
|
||||
return new BaseError({ ...this.snapshot(), cause })
|
||||
}
|
||||
|
||||
private snapshot(): BaseErrorOptions {
|
||||
protected snapshot(): BaseErrorOptions {
|
||||
return {
|
||||
code: this.code,
|
||||
message: this.message,
|
||||
hint: this.hint,
|
||||
httpStatus: this.httpStatus,
|
||||
method: this.method,
|
||||
url: this.url,
|
||||
cause: this.cause,
|
||||
}
|
||||
}
|
||||
@ -76,6 +70,79 @@ export function isBaseError(value: unknown): value is BaseError {
|
||||
return value instanceof BaseError
|
||||
}
|
||||
|
||||
export function isHttpClientError(value: unknown): value is HttpClientError {
|
||||
return value instanceof HttpClientError
|
||||
}
|
||||
|
||||
export function unknownError(message: string, cause?: unknown): BaseError {
|
||||
return new BaseError({ code: ErrorCode.Unknown, message, cause })
|
||||
}
|
||||
|
||||
type HttpClientErrorOptions = BaseErrorOptions & {
|
||||
readonly httpStatus?: number
|
||||
readonly method?: string
|
||||
readonly url?: string
|
||||
readonly rawResponse?: string
|
||||
}
|
||||
|
||||
export class HttpClientError extends BaseError {
|
||||
readonly httpStatus?: number
|
||||
readonly method?: string
|
||||
readonly url?: string
|
||||
readonly rawResponse?: string
|
||||
|
||||
constructor(opts: HttpClientErrorOptions) {
|
||||
super(opts)
|
||||
this.httpStatus = opts.httpStatus
|
||||
this.method = opts.method
|
||||
this.url = opts.url
|
||||
this.rawResponse = opts.rawResponse
|
||||
}
|
||||
|
||||
override toEnvelope(): ErrorEnvelope {
|
||||
const envelope = super.toEnvelope()
|
||||
if (this.httpStatus !== undefined)
|
||||
envelope.error.http_status = this.httpStatus
|
||||
if (this.method !== undefined)
|
||||
envelope.error.method = this.method
|
||||
if (this.url !== undefined)
|
||||
envelope.error.url = this.url
|
||||
if (this.rawResponse !== undefined)
|
||||
envelope.error.raw_response = this.rawResponse
|
||||
return envelope
|
||||
}
|
||||
|
||||
protected override snapshot(): HttpClientErrorOptions {
|
||||
return {
|
||||
...super.snapshot(),
|
||||
httpStatus: this.httpStatus,
|
||||
method: this.method,
|
||||
url: this.url,
|
||||
rawResponse: this.rawResponse,
|
||||
}
|
||||
}
|
||||
|
||||
public static from(error: BaseError): HttpClientError {
|
||||
return new HttpClientError({
|
||||
code: error.code,
|
||||
message: error.message,
|
||||
hint: error.hint,
|
||||
cause: error.cause,
|
||||
})
|
||||
}
|
||||
|
||||
withHttpStatus(httpStatus: number): HttpClientError {
|
||||
return new HttpClientError({ ...this.snapshot(), httpStatus })
|
||||
}
|
||||
|
||||
withRequest(method: string, url: string): HttpClientError {
|
||||
return new HttpClientError({ ...this.snapshot(), method, url })
|
||||
}
|
||||
|
||||
withRawResponse(rawResponse: string): HttpClientError {
|
||||
if (!rawResponse) {
|
||||
return this
|
||||
}
|
||||
return new HttpClientError({ ...this.snapshot(), rawResponse })
|
||||
}
|
||||
}
|
||||
|
||||
@ -42,8 +42,6 @@ describe('error codes', () => {
|
||||
[ErrorCode.UsageMissingArg, ExitCode.Usage],
|
||||
[ErrorCode.ConfigInvalidKey, ExitCode.Usage],
|
||||
[ErrorCode.ConfigInvalidValue, ExitCode.Usage],
|
||||
[ErrorCode.NetworkTimeout, ExitCode.Generic],
|
||||
[ErrorCode.NetworkDns, ExitCode.Generic],
|
||||
[ErrorCode.Server5xx, ExitCode.Generic],
|
||||
[ErrorCode.Server4xxOther, ExitCode.Generic],
|
||||
[ErrorCode.ClientError, ExitCode.Generic],
|
||||
|
||||
@ -11,8 +11,7 @@ export const ErrorCode = {
|
||||
UsageMissingArg: 'usage_missing_arg',
|
||||
ConfigInvalidKey: 'config_invalid_key',
|
||||
ConfigInvalidValue: 'config_invalid_value',
|
||||
NetworkTimeout: 'network_timeout',
|
||||
NetworkDns: 'network_dns',
|
||||
NetworkConnection: 'network_connection',
|
||||
Server5xx: 'server_5xx',
|
||||
Server4xxOther: 'server_4xx_other',
|
||||
ClientError: 'client_error',
|
||||
@ -45,8 +44,7 @@ const CODE_TO_EXIT: Readonly<Record<ErrorCodeValue, ExitCodeValue>> = {
|
||||
usage_missing_arg: ExitCode.Usage,
|
||||
config_invalid_key: ExitCode.Usage,
|
||||
config_invalid_value: ExitCode.Usage,
|
||||
network_timeout: ExitCode.Generic,
|
||||
network_dns: ExitCode.Generic,
|
||||
network_connection: ExitCode.Generic,
|
||||
server_5xx: ExitCode.Generic,
|
||||
server_4xx_other: ExitCode.Generic,
|
||||
client_error: ExitCode.Generic,
|
||||
|
||||
@ -1,57 +0,0 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import { newError } from './base'
|
||||
import { ErrorCode } from './codes'
|
||||
import { renderEnvelope, toEnvelope } from './envelope'
|
||||
|
||||
describe('error envelope', () => {
|
||||
it('emits required fields only when minimal', () => {
|
||||
const err = newError(ErrorCode.Unknown, 'boom')
|
||||
expect(toEnvelope(err)).toEqual({
|
||||
error: { code: 'unknown', message: 'boom' },
|
||||
})
|
||||
})
|
||||
|
||||
it('includes hint / http_status / method / url when present', () => {
|
||||
const err = newError(ErrorCode.NetworkTimeout, 'timed out')
|
||||
.withHint('check your network')
|
||||
.withHttpStatus(504)
|
||||
.withRequest('POST', 'https://api.dify.ai/v1/x')
|
||||
expect(toEnvelope(err)).toEqual({
|
||||
error: {
|
||||
code: 'network_timeout',
|
||||
message: 'timed out',
|
||||
hint: 'check your network',
|
||||
http_status: 504,
|
||||
method: 'POST',
|
||||
url: 'https://api.dify.ai/v1/x',
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
it('renderEnvelope returns a single-line JSON string', () => {
|
||||
const err = newError(ErrorCode.AuthExpired, 'session expired')
|
||||
.withHint('run difyctl auth login')
|
||||
const out = renderEnvelope(err)
|
||||
expect(out).toBe(
|
||||
'{"error":{"code":"auth_expired","message":"session expired","hint":"run difyctl auth login"}}',
|
||||
)
|
||||
expect(out).not.toContain('\n')
|
||||
})
|
||||
|
||||
it('renderEnvelope output round-trips through JSON.parse to an ErrorEnvelope shape', () => {
|
||||
const err = newError(ErrorCode.UsageInvalidFlag, 'bad flag').withHint('see --help')
|
||||
const parsed = JSON.parse(renderEnvelope(err))
|
||||
expect(parsed).toEqual({
|
||||
error: { code: 'usage_invalid_flag', message: 'bad flag', hint: 'see --help' },
|
||||
})
|
||||
})
|
||||
|
||||
it('omits undefined optional fields entirely (no `hint: null`)', () => {
|
||||
const err = newError(ErrorCode.Server5xx, 'upstream broke')
|
||||
const envelope = toEnvelope(err)
|
||||
expect(envelope.error).not.toHaveProperty('hint')
|
||||
expect(envelope.error).not.toHaveProperty('http_status')
|
||||
expect(envelope.error).not.toHaveProperty('method')
|
||||
expect(envelope.error).not.toHaveProperty('url')
|
||||
})
|
||||
})
|
||||
@ -1,32 +0,0 @@
|
||||
import type { BaseError } from './base'
|
||||
|
||||
export type ErrorEnvelope = {
|
||||
error: {
|
||||
code: string
|
||||
message: string
|
||||
hint?: string
|
||||
http_status?: number
|
||||
method?: string
|
||||
url?: string
|
||||
}
|
||||
}
|
||||
|
||||
export function toEnvelope(err: BaseError): ErrorEnvelope {
|
||||
const payload: ErrorEnvelope['error'] = {
|
||||
code: err.code,
|
||||
message: err.message,
|
||||
}
|
||||
if (err.hint !== undefined)
|
||||
payload.hint = err.hint
|
||||
if (err.httpStatus !== undefined)
|
||||
payload.http_status = err.httpStatus
|
||||
if (err.method !== undefined)
|
||||
payload.method = err.method
|
||||
if (err.url !== undefined)
|
||||
payload.url = err.url
|
||||
return { error: payload }
|
||||
}
|
||||
|
||||
export function renderEnvelope(err: BaseError): string {
|
||||
return JSON.stringify(toEnvelope(err))
|
||||
}
|
||||
@ -1,26 +1,58 @@
|
||||
import type { BaseError } from './base'
|
||||
import { isVerbose } from '@/framework/context'
|
||||
import { redactBearer } from '@/http/sanitize'
|
||||
import { colorEnabled, colorScheme } from '@/sys/io/color'
|
||||
import { renderEnvelope } from './envelope'
|
||||
|
||||
export type FormatErrorOptions = {
|
||||
readonly format?: string
|
||||
readonly isErrTTY?: boolean
|
||||
}
|
||||
|
||||
export function formatErrorForCli(err: BaseError, opts: FormatErrorOptions = {}): string {
|
||||
if (opts.format === 'json')
|
||||
return renderEnvelope(err)
|
||||
return humanError(err, opts.isErrTTY ?? false)
|
||||
export type ErrorEnvelope = {
|
||||
error: {
|
||||
code: string
|
||||
message: string
|
||||
hint?: string
|
||||
http_status?: number
|
||||
method?: string
|
||||
url?: string
|
||||
raw_response?: string
|
||||
}
|
||||
}
|
||||
|
||||
function humanError(err: BaseError, isErrTTY: boolean): string {
|
||||
export type PrintableError = {
|
||||
toEnvelope: () => ErrorEnvelope
|
||||
}
|
||||
|
||||
export function formatErrorForCli(err: PrintableError, opts: FormatErrorOptions = {}): string {
|
||||
const env = err.toEnvelope()
|
||||
if (opts.format === 'json')
|
||||
return renderEnvelope(env)
|
||||
return renderHuman(env, opts.isErrTTY ?? false)
|
||||
}
|
||||
|
||||
function renderEnvelope(env: ErrorEnvelope): string {
|
||||
const raw = env.error.raw_response
|
||||
if (raw === undefined)
|
||||
return JSON.stringify(env)
|
||||
if (!isVerbose()) {
|
||||
delete env.error.raw_response
|
||||
return JSON.stringify(env)
|
||||
}
|
||||
env.error.raw_response = redactBearer(raw)
|
||||
return JSON.stringify(env)
|
||||
}
|
||||
|
||||
function renderHuman(env: ErrorEnvelope, isErrTTY: boolean): string {
|
||||
const cs = colorScheme(colorEnabled(isErrTTY))
|
||||
const lines: string[] = [`${err.code}: ${err.message}`]
|
||||
if (err.hint !== undefined)
|
||||
lines.push(`${cs.magenta('hint:')} ${cs.cyan(err.hint)}`)
|
||||
if (err.method !== undefined && err.url !== undefined)
|
||||
lines.push(`request: ${err.method} ${err.url}`)
|
||||
if (err.httpStatus !== undefined)
|
||||
lines.push(`http_status: ${err.httpStatus}`)
|
||||
const e = env.error
|
||||
const lines: string[] = [`${e.code}: ${e.message}`]
|
||||
if (e.hint !== undefined)
|
||||
lines.push(`${cs.magenta('hint:')} ${cs.cyan(e.hint)}`)
|
||||
if (e.method !== undefined && e.url !== undefined)
|
||||
lines.push(`request: ${e.method} ${e.url}`)
|
||||
if (e.http_status !== undefined)
|
||||
lines.push(`http_status: ${e.http_status}`)
|
||||
if (isVerbose() && e.raw_response)
|
||||
lines.push(`raw_response: ${redactBearer(e.raw_response)}`)
|
||||
return lines.join('\n')
|
||||
}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import type { CommandOutput } from './output'
|
||||
import type { ArgDefinition, FlagDefinition, ICommand, InferArgs, InferFlags, OptionalArgValueType } from './types'
|
||||
import { parseArgv } from './flags'
|
||||
import { setVerbose } from './context'
|
||||
import { hasBooleanFlag, parseArgv, VERBOSE_CHAR, VERBOSE_FLAG } from './flags'
|
||||
|
||||
export type CommandConstructor = {
|
||||
new(): Command
|
||||
@ -28,11 +29,16 @@ type ParseResult<C extends CommandConstructor> = {
|
||||
export abstract class Command implements ICommand {
|
||||
static description?: string
|
||||
static flags: Record<string, FlagDefinition<OptionalArgValueType>> = {}
|
||||
|
||||
static args: Record<string, ArgDefinition<string | undefined>> = {}
|
||||
static examples: string[] = []
|
||||
|
||||
abstract run(argv: string[]): Promise<CommandOutput | void>
|
||||
|
||||
processGlobalFlags(argv: readonly string[]): void {
|
||||
setVerbose(hasBooleanFlag(argv, VERBOSE_FLAG, VERBOSE_CHAR))
|
||||
}
|
||||
|
||||
protected parse<C extends CommandConstructor>(ctor: C, argv: string[]): ParseResult<C> {
|
||||
const meta = {
|
||||
flags: ctor.flags ?? {},
|
||||
|
||||
15
cli/src/framework/context.ts
Normal file
15
cli/src/framework/context.ts
Normal file
@ -0,0 +1,15 @@
|
||||
type CommandContext = {
|
||||
verbose: boolean
|
||||
}
|
||||
|
||||
const commandContext: CommandContext = {
|
||||
verbose: false,
|
||||
}
|
||||
|
||||
export function setVerbose(verbose: boolean): void {
|
||||
commandContext.verbose = verbose
|
||||
}
|
||||
|
||||
export function isVerbose(): boolean {
|
||||
return commandContext.verbose
|
||||
}
|
||||
@ -1,6 +1,24 @@
|
||||
import type { ArgDefinition, CommandMeta, FlagDefinition, ParsedArgs, ParsedFlags } from './types'
|
||||
import { UnsupportedArgValueError } from './errors'
|
||||
|
||||
export const VERBOSE_FLAG = 'verbose'
|
||||
export const VERBOSE_CHAR = 'v'
|
||||
|
||||
export const Flags = {
|
||||
string: stringFlag,
|
||||
stringArray: stringRepeatedFlag,
|
||||
boolean: booleanFlag,
|
||||
integer: integerFlag,
|
||||
outputFormat: outputFormatFlag,
|
||||
}
|
||||
|
||||
const GLOBAL_FLAGS: Record<string, FlagDefinition> = {
|
||||
[VERBOSE_FLAG]: Flags.boolean({
|
||||
char: VERBOSE_CHAR,
|
||||
description: 'enable verbose output',
|
||||
}),
|
||||
}
|
||||
|
||||
function stringFlag<const Opts extends {
|
||||
description: string
|
||||
char?: string
|
||||
@ -48,14 +66,6 @@ function integerFlag<const Opts extends { description: string, char?: string, de
|
||||
return { type: 'integer', ...opts } as FlagDefinition<Opts extends { default: number } ? number : number | undefined>
|
||||
}
|
||||
|
||||
export const Flags = {
|
||||
string: stringFlag,
|
||||
stringArray: stringRepeatedFlag,
|
||||
boolean: booleanFlag,
|
||||
integer: integerFlag,
|
||||
outputFormat: outputFormatFlag,
|
||||
}
|
||||
|
||||
function stringArg<const Opts extends { description: string, required?: boolean }>(
|
||||
opts: Opts,
|
||||
): ArgDefinition<Opts extends { required: true } ? string : string | undefined> {
|
||||
@ -99,8 +109,8 @@ function accumulateFlagValue(flags: ParsedFlags, name: string, value: string | b
|
||||
}
|
||||
}
|
||||
|
||||
function resolveByChar(char: string, meta: CommandMeta): [name: string, def: FlagDefinition] | undefined {
|
||||
for (const [name, def] of Object.entries(meta.flags)) {
|
||||
function resolveByChar(char: string, flags: Record<string, FlagDefinition>): [name: string, def: FlagDefinition] | undefined {
|
||||
for (const [name, def] of Object.entries(flags)) {
|
||||
if (def.char === char)
|
||||
return [name, def]
|
||||
}
|
||||
@ -115,12 +125,12 @@ function validateFlagOptions(name: string, raw: string, def: FlagDefinition): vo
|
||||
|
||||
type ResolvedFlag = { name: string, def: FlagDefinition, label: string, inlineRaw: string | undefined }
|
||||
|
||||
function resolveToken(token: string, meta: CommandMeta): ResolvedFlag | null {
|
||||
function resolveToken(token: string, flags: Record<string, FlagDefinition>): ResolvedFlag | null {
|
||||
if (token.startsWith('--')) {
|
||||
const eqIdx = token.indexOf('=')
|
||||
const name = eqIdx !== -1 ? token.slice(2, eqIdx) : token.slice(2)
|
||||
const inlineRaw = eqIdx !== -1 ? token.slice(eqIdx + 1) : undefined
|
||||
const def = meta.flags[name]
|
||||
const def = flags[name]
|
||||
if (!def)
|
||||
throw new Error(`unknown flag: --${name}`)
|
||||
return { name, def, label: `--${name}`, inlineRaw }
|
||||
@ -128,7 +138,7 @@ function resolveToken(token: string, meta: CommandMeta): ResolvedFlag | null {
|
||||
|
||||
if (token.length === 2 && token[1] !== undefined) {
|
||||
const char = token[1]
|
||||
const resolved = resolveByChar(char, meta)
|
||||
const resolved = resolveByChar(char, flags)
|
||||
if (!resolved)
|
||||
throw new Error(`unknown flag: -${char}`)
|
||||
const [name, def] = resolved
|
||||
@ -138,6 +148,21 @@ function resolveToken(token: string, meta: CommandMeta): ResolvedFlag | null {
|
||||
return null
|
||||
}
|
||||
|
||||
// Scans argv for a boolean flag without throwing on unknown tokens, so it is safe
|
||||
// to call before the command-specific flag set is known (e.g. global flags).
|
||||
export function hasBooleanFlag(argv: readonly string[], name: string, char?: string): boolean {
|
||||
for (const token of argv) {
|
||||
if (token === '--')
|
||||
break
|
||||
if (token === `--${name}` || token === `--${name}=true` || token === `--${name}=1`)
|
||||
return true
|
||||
if (char !== undefined && token === `-${char}`)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
export function parseArgv(argv: readonly string[], meta: CommandMeta): { args: ParsedArgs, flags: ParsedFlags } {
|
||||
const flags: ParsedFlags = {}
|
||||
const positional: string[] = []
|
||||
@ -159,7 +184,10 @@ export function parseArgv(argv: readonly string[], meta: CommandMeta): { args: P
|
||||
continue
|
||||
}
|
||||
|
||||
const resolved = resolveToken(token, meta)
|
||||
const resolved = resolveToken(token, {
|
||||
...meta.flags,
|
||||
...GLOBAL_FLAGS, // pass global flags to prevent unknown flag error
|
||||
})
|
||||
if (!resolved) {
|
||||
positional.push(token)
|
||||
continue
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import type { CommandConstructor } from './command'
|
||||
import type { CommandTree } from './registry'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import { BaseError, newError } from '@/errors/base'
|
||||
import { BaseError, HttpClientError, newError } from '@/errors/base'
|
||||
import { ErrorCode, ExitCode } from '@/errors/codes'
|
||||
import { Command } from './command'
|
||||
import { run, sniffOutputFormat } from './run'
|
||||
@ -171,7 +171,7 @@ describe('run() catch routing', () => {
|
||||
it('routes Server5xx error with http_status line and generic exit', async () => {
|
||||
class Throwing extends Command {
|
||||
async run(_argv: string[]) {
|
||||
throw newError(ErrorCode.Server5xx, 'upstream boom').withHttpStatus(502)
|
||||
throw HttpClientError.from(newError(ErrorCode.Server5xx, 'upstream boom')).withHttpStatus(502)
|
||||
}
|
||||
}
|
||||
const result = await captureRun(makeTree(Throwing), ['cmd'])
|
||||
@ -182,7 +182,7 @@ describe('run() catch routing', () => {
|
||||
it('renders request line and http_status when both are present', async () => {
|
||||
class Throwing extends Command {
|
||||
async run(_argv: string[]) {
|
||||
throw newError(ErrorCode.Server5xx, 'upstream boom')
|
||||
throw HttpClientError.from(newError(ErrorCode.Server5xx, 'upstream boom'))
|
||||
.withRequest('GET', 'https://api.dify.ai/v1/me')
|
||||
.withHttpStatus(502)
|
||||
}
|
||||
@ -197,7 +197,7 @@ describe('run() catch routing', () => {
|
||||
it('serializes method and url in JSON envelope', async () => {
|
||||
class Throwing extends Command {
|
||||
async run(_argv: string[]) {
|
||||
throw newError(ErrorCode.Server4xxOther, 'not found')
|
||||
throw HttpClientError.from(newError(ErrorCode.Server4xxOther, 'not found'))
|
||||
.withRequest('GET', 'https://api.dify.ai/v1/apps/x')
|
||||
.withHttpStatus(404)
|
||||
}
|
||||
|
||||
@ -45,7 +45,10 @@ export async function run(tree: CommandTree, argv: string[]): Promise<void> {
|
||||
if (typeof Ctor.deprecated === 'string' && Ctor.deprecated.length > 0)
|
||||
process.stderr.write(`deprecated: ${Ctor.deprecated}\n`)
|
||||
const cmd = new Ctor()
|
||||
const output = await cmd.run(argv.slice(resolved.path.length))
|
||||
const commandArgv = argv.slice(resolved.path.length)
|
||||
cmd.processGlobalFlags(commandArgv)
|
||||
|
||||
const output = await cmd.run(commandArgv)
|
||||
if (output !== undefined)
|
||||
process.stdout.write(stringifyOutput(output))
|
||||
}
|
||||
|
||||
@ -3,7 +3,7 @@ import type { AddressInfo } from 'node:net'
|
||||
import * as http from 'node:http'
|
||||
import { startMock } from '@test/fixtures/dify-mock/server'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { isBaseError } from '@/errors/base'
|
||||
import { isBaseError, isHttpClientError } from '@/errors/base'
|
||||
import { ErrorCode } from '@/errors/codes'
|
||||
import { openAPIBase } from '@/util/host'
|
||||
import { createHttpClient } from './client.js'
|
||||
@ -116,8 +116,8 @@ describe('http client', () => {
|
||||
await client.get('workspaces')
|
||||
}
|
||||
catch (err) { caught = err }
|
||||
expect(isBaseError(caught)).toBe(true)
|
||||
if (isBaseError(caught)) {
|
||||
expect(isHttpClientError(caught)).toBe(true)
|
||||
if (isHttpClientError(caught)) {
|
||||
expect(caught.code).toBe(ErrorCode.AuthExpired)
|
||||
expect(caught.httpStatus).toBe(401)
|
||||
expect(caught.method).toBe('GET')
|
||||
@ -138,8 +138,8 @@ describe('http client', () => {
|
||||
await client.get('workspaces')
|
||||
}
|
||||
catch (err) { caught = err }
|
||||
expect(isBaseError(caught)).toBe(true)
|
||||
if (isBaseError(caught)) {
|
||||
expect(isHttpClientError(caught)).toBe(true)
|
||||
if (isHttpClientError(caught)) {
|
||||
expect(caught.code).toBe(ErrorCode.Server5xx)
|
||||
expect(caught.httpStatus).toBe(503)
|
||||
}
|
||||
@ -187,8 +187,8 @@ describe('http client', () => {
|
||||
await client.get('apps/nope/describe')
|
||||
}
|
||||
catch (err) { caught = err }
|
||||
expect(isBaseError(caught)).toBe(true)
|
||||
if (isBaseError(caught))
|
||||
expect(isHttpClientError(caught)).toBe(true)
|
||||
if (isHttpClientError(caught))
|
||||
expect(caught.code).toBe(ErrorCode.Server4xxOther)
|
||||
})
|
||||
|
||||
@ -205,8 +205,8 @@ describe('http client', () => {
|
||||
await client.get('workspaces')
|
||||
}
|
||||
catch (err) { caught = err }
|
||||
expect(isBaseError(caught)).toBe(true)
|
||||
if (isBaseError(caught))
|
||||
expect(isHttpClientError(caught)).toBe(true)
|
||||
if (isHttpClientError(caught))
|
||||
expect(caught.httpStatus).toBe(429)
|
||||
})
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ import type {
|
||||
RequestOptions,
|
||||
ResolvedOptions,
|
||||
} from './types.js'
|
||||
import { isVerbose } from '@/framework/context'
|
||||
import { userAgent as defaultUserAgent } from '@/version/info'
|
||||
import { buildBody } from './body.js'
|
||||
import { classifyResponse } from './error-mapper.js'
|
||||
@ -133,11 +134,11 @@ async function dispatch(state: ClientState, path: string, opts: RequestOptions,
|
||||
|
||||
await runHooks(state.hooks.onRequest, ctx)
|
||||
|
||||
// `dispatcher` is an undici extension to RequestInit, not in @types/node's fetch
|
||||
// signature — hence the local type. Carries proxy routing when a proxy env var is set.
|
||||
const init: RequestInit & { dispatcher?: unknown } = { signal }
|
||||
const init: RequestInit & { dispatcher?: unknown, verbose?: boolean } = { signal }
|
||||
if (state.dispatcher !== undefined)
|
||||
init.dispatcher = state.dispatcher
|
||||
if (isVerbose())
|
||||
init.verbose = true
|
||||
|
||||
try {
|
||||
ctx.response = await fetch(ctx.request, init)
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import type { BaseError } from '@/errors/base'
|
||||
import { newError } from '@/errors/base'
|
||||
import { BaseError, HttpClientError, newError } from '@/errors/base'
|
||||
import { ErrorCode } from '@/errors/codes'
|
||||
import { redactBearer } from './sanitize'
|
||||
|
||||
@ -32,54 +31,50 @@ async function readBody(response: Response): Promise<{ raw: string, parsed?: Wir
|
||||
}
|
||||
|
||||
export async function classifyResponse(request: Request, response: Response): Promise<BaseError> {
|
||||
const { parsed } = await readBody(response.clone())
|
||||
const { parsed, raw } = await readBody(response.clone())
|
||||
const wire: WireFields = parsed?.error ?? parsed ?? {}
|
||||
const status = response.status
|
||||
const url = redactBearer(response.url || request.url)
|
||||
const method = request.method
|
||||
|
||||
if (status === 401) {
|
||||
return newError(
|
||||
return HttpClientError.from(newError(
|
||||
ErrorCode.AuthExpired,
|
||||
wire.message ?? 'session expired or revoked',
|
||||
)
|
||||
))
|
||||
.withHint(wire.hint ?? 'run \'difyctl auth login\' to sign in again')
|
||||
.withHttpStatus(status)
|
||||
.withRequest(method, url)
|
||||
}
|
||||
|
||||
if (status >= 500) {
|
||||
return newError(
|
||||
return HttpClientError.from(newError(
|
||||
ErrorCode.Server5xx,
|
||||
wire.message ?? `server error (HTTP ${status})`,
|
||||
)
|
||||
))
|
||||
.withHttpStatus(status)
|
||||
.withRequest(method, url)
|
||||
.withRawResponse(raw)
|
||||
}
|
||||
|
||||
const err = newError(
|
||||
const err = HttpClientError.from(newError(
|
||||
ErrorCode.Server4xxOther,
|
||||
wire.message ?? `request failed (HTTP ${status})`,
|
||||
)
|
||||
))
|
||||
.withHttpStatus(status)
|
||||
.withRequest(method, url)
|
||||
.withRawResponse(raw)
|
||||
return wire.hint !== undefined ? err.withHint(wire.hint) : err
|
||||
}
|
||||
|
||||
export function classifyTransportError(err: unknown): BaseError {
|
||||
const message = err instanceof Error ? err.message : String(err)
|
||||
const sanitized = redactBearer(message)
|
||||
|
||||
if (err instanceof Error && err.name === 'TimeoutError')
|
||||
return newError(ErrorCode.NetworkTimeout, 'request timed out').wrap(err)
|
||||
if (err instanceof Error && err.name === 'AbortError')
|
||||
return newError(ErrorCode.NetworkTimeout, 'request aborted').wrap(err)
|
||||
if (sanitized.toLowerCase().includes('econnrefused'))
|
||||
return newError(ErrorCode.NetworkDns, 'connection refused').wrap(err)
|
||||
if (sanitized.toLowerCase().includes('enotfound'))
|
||||
return newError(ErrorCode.NetworkDns, 'host lookup failed').wrap(err)
|
||||
if (sanitized.toLowerCase().includes('etimedout'))
|
||||
return newError(ErrorCode.NetworkTimeout, 'connection timed out').wrap(err)
|
||||
|
||||
return newError(ErrorCode.Unknown, sanitized).wrap(err)
|
||||
if (err instanceof BaseError) {
|
||||
return err
|
||||
}
|
||||
if (!(err instanceof Error)) {
|
||||
return newError(ErrorCode.Unknown, String(err)).wrap(err)
|
||||
}
|
||||
const sanitized = redactBearer(err.message)
|
||||
// there isn't a practical way to classify network errors reliably
|
||||
return newError(ErrorCode.NetworkConnection, sanitized).wrap(err)
|
||||
}
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
{
|
||||
"extends": "@dify/tsconfig/node.json",
|
||||
"compilerOptions": {
|
||||
"rootDir": "src",
|
||||
"moduleResolution": "bundler",
|
||||
"paths": {
|
||||
"@/*": [
|
||||
@ -12,12 +11,10 @@
|
||||
]
|
||||
},
|
||||
"types": ["node"],
|
||||
"declaration": true,
|
||||
"declarationMap": true,
|
||||
"noEmit": false,
|
||||
"noEmit": true, // we already have bundlers to handle this.
|
||||
"outDir": "dist",
|
||||
"sourceMap": true
|
||||
},
|
||||
"include": ["src/**/*.ts"],
|
||||
"exclude": ["dist", "test", "node_modules", "**/*.test.ts"]
|
||||
"include": ["src/**/*.ts", "test/**/*.ts"], // tests must be included for typechecking
|
||||
"exclude": ["node_modules", "dist"]
|
||||
}
|
||||
|
||||
@ -256,7 +256,9 @@ class DifyShellRuntimeState(BaseModel):
|
||||
raise ValueError("workspace_cwd requires a matching session_id.")
|
||||
expected_workspace = _workspace_cwd(self.session_id)
|
||||
if self.workspace_cwd != expected_workspace:
|
||||
raise ValueError(f"workspace_cwd must equal {expected_workspace!r} for session_id {self.session_id!r}.")
|
||||
raise ValueError(
|
||||
f"workspace_cwd must equal {expected_workspace!r} for session_id {self.session_id!r}."
|
||||
)
|
||||
unknown_offset_job_ids = set(self.job_offsets) - set(self.job_ids)
|
||||
if unknown_offset_job_ids:
|
||||
names = ", ".join(sorted(unknown_offset_job_ids))
|
||||
@ -692,12 +694,12 @@ def _workspace_mkdir_script(*, session_id: str) -> str:
|
||||
of silently reusing another session's workspace.
|
||||
"""
|
||||
safe_session_id = _validated_session_id(session_id)
|
||||
workspace_dir = f"$HOME/workspace/{safe_session_id}"
|
||||
workspace_dir = f'$HOME/workspace/{safe_session_id}'
|
||||
return (
|
||||
'mkdir -p "$HOME/workspace"; '
|
||||
f'if mkdir "{workspace_dir}"; then exit 0; fi; '
|
||||
f'if [ -e "{workspace_dir}" ]; then exit {_WORKSPACE_COLLISION_EXIT_CODE}; fi; '
|
||||
"exit 1"
|
||||
'exit 1'
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -277,7 +277,6 @@ def test_shell_layer_suspend_and_resume_reuse_state_with_fresh_clients() -> None
|
||||
return next(clients)
|
||||
|
||||
compositor = Compositor([LayerNode("shell", _shell_provider(client_factory=factory))])
|
||||
|
||||
async def scenario() -> None:
|
||||
async with compositor.enter(configs={"shell": DifyShellLayerConfig()}) as run:
|
||||
shell_layer = run.get_layer("shell", DifyShellLayer)
|
||||
@ -343,10 +342,7 @@ def test_shell_layer_delete_removes_workspace_then_force_deletes_tracked_jobs_an
|
||||
|
||||
assert client.events[:2] == [("run", 'rm -rf -- "$HOME/workspace/abc12ff"'), ("wait", "cleanup-job")]
|
||||
assert {call.job_id for call in client.delete_calls} == {"user-job", "mkdir-job", "cleanup-job"}
|
||||
assert all(
|
||||
client.events.index(("delete", call.job_id)) > client.events.index(("wait", "cleanup-job"))
|
||||
for call in client.delete_calls
|
||||
)
|
||||
assert all(client.events.index(("delete", call.job_id)) > client.events.index(("wait", "cleanup-job")) for call in client.delete_calls)
|
||||
assert all(call.force is True for call in client.delete_calls)
|
||||
assert layer.runtime_state.job_ids == []
|
||||
assert layer.runtime_state.job_offsets == {}
|
||||
|
||||
@ -27,9 +27,7 @@ def test_default_layer_providers_register_shell_layer_with_configured_token_fact
|
||||
|
||||
return factory
|
||||
|
||||
monkeypatch.setattr(
|
||||
compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory
|
||||
)
|
||||
monkeypatch.setattr(compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory)
|
||||
|
||||
providers = create_default_layer_providers(
|
||||
shellctl_entrypoint="http://shellctl.example",
|
||||
@ -58,9 +56,7 @@ def test_default_layer_providers_keep_empty_shellctl_token_by_default(
|
||||
|
||||
return factory
|
||||
|
||||
monkeypatch.setattr(
|
||||
compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory
|
||||
)
|
||||
monkeypatch.setattr(compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory)
|
||||
|
||||
providers = create_default_layer_providers(shellctl_entrypoint="http://shellctl.example")
|
||||
shell_provider = next(provider for provider in providers if provider.type_id == DIFY_SHELL_LAYER_TYPE_ID)
|
||||
|
||||
@ -684,8 +684,7 @@ def test_runner_rejects_duplicate_tool_names_between_shell_and_other_layers(
|
||||
),
|
||||
)
|
||||
layer_providers = tuple(
|
||||
provider
|
||||
for provider in create_default_layer_providers(shellctl_entrypoint="http://unused")
|
||||
provider for provider in create_default_layer_providers(shellctl_entrypoint="http://unused")
|
||||
if provider.type_id != DIFY_SHELL_LAYER_TYPE_ID
|
||||
) + (shell_provider,)
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
|
||||
# Core service URLs
|
||||
CONSOLE_API_URL=
|
||||
SERVER_CONSOLE_API_URL=http://api:5001
|
||||
CONSOLE_WEB_URL=
|
||||
SERVICE_API_URL=
|
||||
TRIGGER_URL=http://localhost
|
||||
|
||||
@ -376,6 +376,7 @@ services:
|
||||
- ./.env
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
SERVER_CONSOLE_API_URL: ${SERVER_CONSOLE_API_URL:-http://api:5001}
|
||||
APP_API_URL: ${APP_API_URL:-}
|
||||
AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-}
|
||||
NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-}
|
||||
|
||||
@ -382,6 +382,7 @@ services:
|
||||
- ./.env
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
SERVER_CONSOLE_API_URL: ${SERVER_CONSOLE_API_URL:-http://api:5001}
|
||||
APP_API_URL: ${APP_API_URL:-}
|
||||
AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-}
|
||||
NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-}
|
||||
|
||||
@ -3406,14 +3406,6 @@
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"web/app/components/workflow/block-selector/rag-tool-recommendations/index.tsx": {
|
||||
"no-restricted-properties": {
|
||||
"count": 3
|
||||
},
|
||||
"react/set-state-in-effect": {
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"web/app/components/workflow/block-selector/tool-picker.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
@ -4405,11 +4397,6 @@
|
||||
"count": 9
|
||||
}
|
||||
},
|
||||
"web/app/components/workflow/nodes/question-classifier/components/class-list.tsx": {
|
||||
"no-restricted-properties": {
|
||||
"count": 2
|
||||
}
|
||||
},
|
||||
"web/app/components/workflow/nodes/question-classifier/use-single-run-form-params.ts": {
|
||||
"ts/no-explicit-any": {
|
||||
"count": 8
|
||||
|
||||
@ -4,14 +4,14 @@ import type { ViewType } from '@/app/components/workflow/block-selector/view-typ
|
||||
import type { OnSelectBlock } from '@/app/components/workflow/types'
|
||||
import { RiMoreLine } from '@remixicon/react'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react'
|
||||
import { useCallback, useMemo } from 'react'
|
||||
import { Trans, useTranslation } from 'react-i18next'
|
||||
import { ArrowDownRoundFill } from '@/app/components/base/icons/src/vender/solid/arrows'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { getFormattedPlugin } from '@/app/components/plugins/marketplace/utils'
|
||||
import { useLocalStorage } from '@/hooks/use-local-storage'
|
||||
import Link from '@/next/link'
|
||||
import { useRAGRecommendedPlugins } from '@/service/use-tools'
|
||||
import { isServer } from '@/utils/client'
|
||||
import { getMarketplaceUrl } from '@/utils/var'
|
||||
import List from './list'
|
||||
|
||||
@ -29,26 +29,7 @@ const RAGToolRecommendations = ({
|
||||
onTagsChange,
|
||||
}: RAGToolRecommendationsProps) => {
|
||||
const { t } = useTranslation()
|
||||
const [isCollapsed, setIsCollapsed] = useState<boolean>(() => {
|
||||
if (isServer)
|
||||
return false
|
||||
const stored = window.localStorage.getItem(STORAGE_KEY)
|
||||
return stored === 'true'
|
||||
})
|
||||
|
||||
useEffect(() => {
|
||||
if (isServer)
|
||||
return
|
||||
const stored = window.localStorage.getItem(STORAGE_KEY)
|
||||
if (stored !== null)
|
||||
setIsCollapsed(stored === 'true')
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
if (isServer)
|
||||
return
|
||||
window.localStorage.setItem(STORAGE_KEY, String(isCollapsed))
|
||||
}, [isCollapsed])
|
||||
const [isCollapsed, setIsCollapsed] = useLocalStorage<boolean>(STORAGE_KEY, false)
|
||||
|
||||
const {
|
||||
data: ragRecommendedPlugins,
|
||||
|
||||
@ -11,6 +11,7 @@ import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { ReactSortable } from 'react-sortablejs'
|
||||
import { ArrowDownRoundFill } from '@/app/components/base/icons/src/vender/solid/general'
|
||||
import { useLocalStorage } from '@/hooks/use-local-storage'
|
||||
import { useEdgesInteractions } from '../../../hooks'
|
||||
import AddButton from '../../_base/components/add-button'
|
||||
import Item from './class-item'
|
||||
@ -42,17 +43,8 @@ const ClassList: FC<Props> = ({
|
||||
const [shouldScrollToEnd, setShouldScrollToEnd] = useState(false)
|
||||
const prevListLength = useRef(list.length)
|
||||
const [collapsed, setCollapsed] = useState(false)
|
||||
const [isRenameHintDismissed, setIsRenameHintDismissed] = useState(() => {
|
||||
if (typeof window === 'undefined')
|
||||
return true
|
||||
|
||||
try {
|
||||
return window.localStorage.getItem(INLINE_LABEL_HINT_STORAGE_KEY) === 'true'
|
||||
}
|
||||
catch {
|
||||
return false
|
||||
}
|
||||
})
|
||||
const [storedRenameHintDismissed, setIsRenameHintDismissed] = useLocalStorage<boolean>(INLINE_LABEL_HINT_STORAGE_KEY)
|
||||
const isRenameHintDismissed = storedRenameHintDismissed ?? false
|
||||
|
||||
const handleClassChange = useCallback((index: number) => {
|
||||
return (value: Topic) => {
|
||||
@ -104,12 +96,7 @@ const ClassList: FC<Props> = ({
|
||||
return
|
||||
|
||||
setIsRenameHintDismissed(true)
|
||||
try {
|
||||
window.localStorage.setItem(INLINE_LABEL_HINT_STORAGE_KEY, 'true')
|
||||
}
|
||||
catch {
|
||||
}
|
||||
}, [isRenameHintDismissed])
|
||||
}, [isRenameHintDismissed, setIsRenameHintDismissed])
|
||||
|
||||
const shouldShowRenameHint = !readonly && !isRenameHintDismissed && list.some((item, index) => {
|
||||
return isDefaultClassLabel(item.label, index + 1, t)
|
||||
|
||||
43
web/config/__tests__/server.spec.ts
Normal file
43
web/config/__tests__/server.spec.ts
Normal file
@ -0,0 +1,43 @@
|
||||
// @vitest-environment node
|
||||
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
vi.mock('server-only', () => ({}))
|
||||
|
||||
const importServerConfig = async () => {
|
||||
vi.resetModules()
|
||||
return import('../server')
|
||||
}
|
||||
|
||||
describe('server config', () => {
|
||||
afterEach(() => {
|
||||
vi.unstubAllEnvs()
|
||||
})
|
||||
|
||||
it('should prefer the server-only console API URL for server requests', async () => {
|
||||
vi.stubEnv('SERVER_CONSOLE_API_URL', 'http://api:5001')
|
||||
vi.stubEnv('CONSOLE_API_URL', 'https://console.example.com')
|
||||
|
||||
const { SERVER_CONSOLE_API_PREFIX } = await importServerConfig()
|
||||
|
||||
expect(SERVER_CONSOLE_API_PREFIX).toBe('http://api:5001/console/api')
|
||||
})
|
||||
|
||||
it('should fall back to the public console API URL when no server-only URL is configured', async () => {
|
||||
vi.stubEnv('SERVER_CONSOLE_API_URL', '')
|
||||
vi.stubEnv('CONSOLE_API_URL', 'https://console.example.com')
|
||||
|
||||
const { SERVER_CONSOLE_API_PREFIX } = await importServerConfig()
|
||||
|
||||
expect(SERVER_CONSOLE_API_PREFIX).toBe('https://console.example.com/console/api')
|
||||
})
|
||||
|
||||
it('should remain unconfigured when both server URLs are empty', async () => {
|
||||
vi.stubEnv('SERVER_CONSOLE_API_URL', '')
|
||||
vi.stubEnv('CONSOLE_API_URL', '')
|
||||
|
||||
const { SERVER_CONSOLE_API_PREFIX } = await importServerConfig()
|
||||
|
||||
expect(SERVER_CONSOLE_API_PREFIX).toBeUndefined()
|
||||
})
|
||||
})
|
||||
@ -5,6 +5,8 @@ import 'server-only'
|
||||
const withoutTrailingSlash = (value: string) => value.endsWith('/') ? value.slice(0, -1) : value
|
||||
|
||||
// Server-side requests need the origin; browser requests should keep using NEXT_PUBLIC_API_PREFIX.
|
||||
export const SERVER_CONSOLE_API_PREFIX = env.CONSOLE_API_URL
|
||||
? `${withoutTrailingSlash(env.CONSOLE_API_URL)}/console/api`
|
||||
const serverConsoleApiUrl = env.SERVER_CONSOLE_API_URL || env.CONSOLE_API_URL
|
||||
|
||||
export const SERVER_CONSOLE_API_PREFIX = serverConsoleApiUrl
|
||||
? `${withoutTrailingSlash(serverConsoleApiUrl)}/console/api`
|
||||
: undefined
|
||||
|
||||
@ -161,6 +161,7 @@ const clientSchema = {
|
||||
export const env = createEnv({
|
||||
server: {
|
||||
CONSOLE_API_URL: z.string().optional(),
|
||||
SERVER_CONSOLE_API_URL: z.string().optional(),
|
||||
/**
|
||||
* Maximum length of segmentation tokens for indexing
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user