Compare commits

..

5 Commits

Author SHA1 Message Date
yyh
c39861b33e fix: configure server console api url 2026-06-02 22:48:11 +08:00
f591da7865 ci: ruff cover agent (#36949) 2026-06-02 11:40:19 +00:00
f19679b217 refactor: improve network error and allow verbose output (#36923) 2026-06-02 10:43:40 +00:00
b682591c7a refactor(web): migrate question classifier label hint storage (#36932)
Co-authored-by: lmlm <7487674+popsiclelmlm@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-02 10:28:50 +00:00
8f6b59feff refactor(web): migrate rag recommendations collapsed storage (#36940)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-02 09:08:51 +00:00
63 changed files with 945 additions and 929 deletions

View File

@ -51,6 +51,15 @@ jobs:
with:
files: |
api/**
- name: Check dify-agent inputs
if: github.event_name != 'merge_group'
id: dify-agent-changes
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
dify-agent/**/*.py
dify-agent/pyproject.toml
dify-agent/uv.lock
- if: github.event_name != 'merge_group'
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
with:
@ -76,6 +85,17 @@ jobs:
# Format code
uv run ruff format ..
- if: github.event_name != 'merge_group' && steps.dify-agent-changes.outputs.any_changed == 'true'
run: |
cd dify-agent
uv sync --dev
# fmt first to avoid line too long
uv run ruff format .
# Fix lint errors
uv run ruff check --fix .
# Format code
uv run ruff format .
- name: count migration progress
if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true'
run: |

View File

@ -147,7 +147,7 @@ class AppDescribeApi(AppReadResource):
class AppListApi(Resource):
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
@auth_router.guard_workspace(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def get(self, *, auth_data: AuthData):
try:
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))

View File

@ -1,12 +1,11 @@
from __future__ import annotations
from controllers.openapi.auth.conditions import (
EDITION_CE,
EDITION_EE,
HAS_ALLOWED_ROLES,
LOADED_APP_IS_PRIVATE,
PATH_HAS_APP_ID,
WEBAPP_AUTH_ENABLED,
WORKSPACE_MEMBERSHIP_REQUIRED,
)
from controllers.openapi.auth.data import Edition
from controllers.openapi.auth.flow import When
@ -16,18 +15,14 @@ from controllers.openapi.auth.prepare import (
load_app,
load_app_access_mode,
load_tenant,
load_tenant_from_request,
load_workspace_role,
resolve_external_user,
)
from controllers.openapi.auth.verify import (
check_acl,
check_app_api_enabled,
check_app_access,
check_membership,
check_private_app_permission,
check_scope,
check_workspace_mismatch,
check_workspace_role,
)
from libs.oauth_bearer import TokenType
@ -35,17 +30,13 @@ account_pipeline = AuthPipeline(
prepare=[
When(PATH_HAS_APP_ID, then=load_app),
When(PATH_HAS_APP_ID, then=load_tenant),
When(WORKSPACE_MEMBERSHIP_REQUIRED & ~PATH_HAS_APP_ID, then=load_tenant_from_request),
load_account,
When(HAS_ALLOWED_ROLES, then=load_workspace_role),
load_account, # all tokens here are account tokens
When(PATH_HAS_APP_ID & EDITION_EE, then=load_app_access_mode),
],
auth=[
When(PATH_HAS_APP_ID, then=check_app_api_enabled),
check_scope,
When((PATH_HAS_APP_ID | WORKSPACE_MEMBERSHIP_REQUIRED) & ~HAS_ALLOWED_ROLES, then=check_membership),
When(WORKSPACE_MEMBERSHIP_REQUIRED & PATH_HAS_APP_ID, then=check_workspace_mismatch),
When(HAS_ALLOWED_ROLES, then=check_workspace_role),
When(EDITION_CE & PATH_HAS_APP_ID, then=check_membership),
When(EDITION_EE & PATH_HAS_APP_ID & ~WEBAPP_AUTH_ENABLED, then=check_app_access),
When(PATH_HAS_APP_ID & EDITION_EE & WEBAPP_AUTH_ENABLED, then=check_acl),
When(EDITION_EE & LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
],
@ -59,7 +50,6 @@ external_sso_pipeline = AuthPipeline(
When(PATH_HAS_APP_ID, then=load_app_access_mode),
],
auth=[
When(PATH_HAS_APP_ID, then=check_app_api_enabled),
check_scope,
When(PATH_HAS_APP_ID & WEBAPP_AUTH_ENABLED, then=check_acl),
When(LOADED_APP_IS_PRIVATE, then=check_private_app_permission),

View File

@ -50,7 +50,4 @@ EDITION_SAAS = config_cond(lambda: current_edition() == Edition.SAAS)
WEBAPP_AUTH_ENABLED = config_cond(lambda: FeatureService.get_system_features().webapp_auth.enabled)
WORKSPACE_MEMBERSHIP_REQUIRED = request_cond(lambda ctx: ctx.workspace_membership)
HAS_ALLOWED_ROLES = request_cond(lambda ctx: ctx.allowed_roles is not None)
LOADED_APP_IS_PRIVATE = data_cond(lambda data: data.app_access_mode == WebAppAccessMode.PRIVATE)

View File

@ -9,7 +9,7 @@ from werkzeug.exceptions import InternalServerError
from configs import dify_config
from libs.oauth_bearer import Scope, TokenType
from models.account import Account, Tenant, TenantAccountRole
from models.account import Account, Tenant
from models.model import App, EndUser
from services.enterprise.enterprise_service import WebAppAccessMode
@ -41,8 +41,6 @@ class RequestContext(BaseModel):
token_type: TokenType
scope: Scope | None = None
path_params: dict[str, str]
workspace_membership: bool = False
allowed_roles: frozenset[TenantAccountRole] | None = None
class AuthData(BaseModel):
@ -58,14 +56,10 @@ class AuthData(BaseModel):
external_identity: ExternalIdentity | None = None
path_params: dict[str, str] = Field(default_factory=dict)
allowed_roles: frozenset[TenantAccountRole] | None = None
app: App | None = None
tenant: Tenant | None = None
app_access_mode: WebAppAccessMode | None = None
tenant_role: TenantAccountRole | None = None
caller: Account | EndUser | None = None
caller_kind: Literal["account", "end_user"] | None = None

View File

@ -34,7 +34,6 @@ from libs.oauth_bearer import (
reset_auth_ctx,
set_auth_ctx,
)
from models.account import TenantAccountRole
from services.feature_service import FeatureService, LicenseStatus
@ -57,15 +56,11 @@ class AuthPipeline:
view: Callable,
*,
scope: Scope | None,
workspace_membership: bool = False,
allowed_roles: frozenset[TenantAccountRole] | None = None,
) -> Any:
req_ctx = RequestContext(
token_type=identity.token_type,
scope=scope,
path_params=dict(request.view_args or {}),
workspace_membership=workspace_membership,
allowed_roles=allowed_roles,
)
data = AuthData(
@ -76,7 +71,6 @@ class AuthPipeline:
scopes=frozenset(identity.scopes),
tenants=dict(identity.verified_tenants),
required_scope=scope,
allowed_roles=allowed_roles,
path_params=dict(req_ctx.path_params),
external_identity=(
ExternalIdentity(email=identity.subject_email, issuer=identity.subject_issuer)
@ -127,41 +121,6 @@ class PipelineRouter:
scope: Scope | None = None,
allowed_token_types: frozenset[TokenType] | None = None,
edition: frozenset[Edition] | None = None,
workspace_membership: bool = False,
allowed_roles: frozenset[TenantAccountRole] | None = None,
) -> Callable:
return self._make_decorator(
scope=scope,
allowed_token_types=allowed_token_types,
edition=edition,
workspace_membership=workspace_membership,
allowed_roles=allowed_roles,
)
def guard_workspace(
self,
*,
scope: Scope | None = None,
allowed_token_types: frozenset[TokenType] | None = None,
edition: frozenset[Edition] | None = None,
allowed_roles: frozenset[TenantAccountRole] | None = None,
) -> Callable:
return self._make_decorator(
scope=scope,
allowed_token_types=allowed_token_types,
edition=edition,
workspace_membership=True,
allowed_roles=allowed_roles,
)
def _make_decorator(
self,
*,
scope: Scope | None,
allowed_token_types: frozenset[TokenType] | None,
edition: frozenset[Edition] | None,
workspace_membership: bool,
allowed_roles: frozenset[TenantAccountRole] | None,
) -> Callable:
def decorator(view: Callable) -> Callable:
@wraps(view)
@ -173,8 +132,6 @@ class PipelineRouter:
scope=scope,
allowed_token_types=allowed_token_types,
edition=edition,
workspace_membership=workspace_membership,
allowed_roles=allowed_roles,
)
return decorated
@ -190,8 +147,6 @@ class PipelineRouter:
scope: Scope | None,
allowed_token_types: frozenset[TokenType] | None,
edition: frozenset[Edition] | None,
workspace_membership: bool = False,
allowed_roles: frozenset[TenantAccountRole] | None = None,
) -> Any:
# 404 not 403 — this edition doesn't expose the feature at all
if edition is not None and current_edition() not in edition:
@ -227,15 +182,7 @@ class PipelineRouter:
if not license_checked and Edition.EE in route.required_edition:
_check_license()
return route.pipeline._run(
identity,
args,
kwargs,
view,
scope=scope,
workspace_membership=workspace_membership,
allowed_roles=allowed_roles,
)
return route.pipeline._run(identity, args, kwargs, view, scope=scope)
def _should_run(step: Any, req_ctx: RequestContext, data: AuthData | None) -> bool:

View File

@ -1,8 +1,5 @@
from __future__ import annotations
import uuid
from flask import request
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound, Unauthorized
from controllers.openapi.auth.data import AuthData
@ -16,18 +13,16 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppAcce
def load_app(data: AuthData) -> None:
if data.app is not None:
return
app_id = data.path_params["app_id"]
app = AppService.get_app_by_id(db.session, app_id)
if not app or app.status != "normal":
raise NotFound("app not found")
if not app.enable_api:
raise Forbidden("service_api_disabled")
data.app = app
def load_tenant(data: AuthData) -> None:
if data.tenant is not None:
return
if data.app is None:
raise InternalServerError("pipeline_invariant_violated: app not loaded before load_tenant")
tenant = TenantService.get_tenant_by_id(db.session, str(data.app.tenant_id))
@ -36,25 +31,7 @@ def load_tenant(data: AuthData) -> None:
data.tenant = tenant
def load_tenant_from_request(data: AuthData) -> None:
if data.tenant is not None:
return
workspace_id = data.path_params.get("workspace_id") or request.args.get("workspace_id")
if not workspace_id:
raise NotFound("workspace not found")
try:
uuid.UUID(workspace_id)
except ValueError:
raise NotFound("workspace not found")
tenant = TenantService.get_tenant_by_id(db.session, workspace_id)
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
raise NotFound("workspace not found")
data.tenant = tenant
def load_account(data: AuthData) -> None:
if data.caller is not None:
return
account = AccountService.get_account_by_id(db.session, str(data.account_id))
if account is None:
raise Unauthorized("account not found")
@ -64,19 +41,6 @@ def load_account(data: AuthData) -> None:
data.caller_kind = "account"
def load_workspace_role(data: AuthData) -> None:
if data.tenant_role is not None:
return
if data.tenant is None or data.account_id is None:
return
if data.caller is not None and getattr(data.caller, "status", None) != "active":
return
role = TenantService.get_account_role_in_tenant(db.session, str(data.account_id), str(data.tenant.id))
if role is None:
return
data.tenant_role = role
def resolve_external_user(data: AuthData) -> None:
if data.tenant is None or data.app is None or data.external_identity is None:
raise Unauthorized("missing context for external user resolution")

View File

@ -0,0 +1,77 @@
"""Workspace role gate.
Layered on top of `validate_bearer` + `accept_subjects(SubjectType.ACCOUNT)`
for routes whose access depends on the caller's `TenantAccountJoin.role`
in the workspace named by the `workspace_id` path parameter.
Usage::
@openapi_ns.route("/workspaces/<string:workspace_id>/members")
class Members(Resource):
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
@require_workspace_role() # any member
def get(self, workspace_id: str): ...
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
def post(self, workspace_id: str): ...
Non-member callers get 404 (matching `GET /openapi/v1/workspaces/<id>`)
so workspace IDs do not leak across tenants. A member without one of the
allowed roles gets 403.
"""
from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import TypeVar
from werkzeug.exceptions import Forbidden, NotFound
from extensions.ext_database import db
from libs.oauth_bearer import try_get_auth_ctx
from models.account import TenantAccountRole
from services.account_service import TenantService
F = TypeVar("F", bound=Callable[..., object])
def require_workspace_role(*allowed_roles: TenantAccountRole) -> Callable[[F], F]:
"""Gate a route on the caller's role in ``workspace_id``.
Pass no roles to require only membership. Pass one or more roles to
require the caller's role be in that set.
"""
allowed = frozenset(allowed_roles)
def deco(fn: F) -> F:
@wraps(fn)
def wrapper(*args: object, **kwargs: object) -> object:
ctx = try_get_auth_ctx()
if ctx is None or ctx.account_id is None:
raise RuntimeError(
"require_workspace_role called without account-bearer context; "
"stack validate_bearer + accept_subjects(SubjectType.ACCOUNT) above it"
)
workspace_id = kwargs.get("workspace_id")
if not workspace_id:
raise RuntimeError("require_workspace_role expects a 'workspace_id' route parameter")
role = TenantService.get_account_role_in_tenant(db.session, str(ctx.account_id), str(workspace_id))
if role is None:
raise NotFound("workspace not found")
if allowed and role not in allowed:
raise Forbidden("insufficient workspace role")
return fn(*args, **kwargs)
return wrapper # type: ignore[return-value]
return deco

View File

@ -1,7 +1,6 @@
from __future__ import annotations
from flask import request
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized, UnprocessableEntity
from werkzeug.exceptions import Forbidden, Unauthorized
from controllers.openapi.auth.data import AuthData
from extensions.ext_database import db
@ -31,30 +30,6 @@ def check_membership(data: AuthData) -> None:
)
def check_workspace_mismatch(data: AuthData) -> None:
if data.tenant is None:
return
request_workspace_id = data.path_params.get("workspace_id") or request.args.get("workspace_id")
if request_workspace_id and request_workspace_id != str(data.tenant.id):
raise UnprocessableEntity("workspace_id does not match app's workspace")
def check_workspace_role(data: AuthData) -> None:
if data.allowed_roles is None:
return
if data.tenant_role is None:
raise NotFound("workspace not found")
if data.tenant_role not in data.allowed_roles:
raise Forbidden("insufficient workspace role")
def check_app_api_enabled(data: AuthData) -> None:
if data.app is None:
return
if not data.app.enable_api:
raise Forbidden("service_api_disabled")
def check_app_access(data: AuthData) -> None:
if data.tenant is None:
return

View File

@ -5,8 +5,9 @@ endpoints. Account bearers (dfoa_) see every tenant they're a member of.
External SSO bearers (dfoe_) have no account_id and so see an empty list —
that matches /openapi/v1/account.
Member-management endpoints use ``guard_workspace`` which enforces
workspace membership and optional role requirements via the auth pipeline.
Member-management endpoints are gated by both `accept_subjects` (SSO out)
and `require_workspace_role` (membership / role lookup against the path's
``workspace_id``).
"""
from __future__ import annotations
@ -36,6 +37,7 @@ from controllers.openapi._models import (
)
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from controllers.openapi.auth.role_gate import require_workspace_role
from extensions.ext_database import db
from libs.oauth_bearer import Scope, TokenType
from models import Account, Tenant, TenantAccountJoin
@ -150,7 +152,8 @@ class WorkspaceSwitchApi(Resource):
"""
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
@auth_router.guard_workspace(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@require_workspace_role()
def post(self, workspace_id: str, *, auth_data: AuthData):
account = _load_account(auth_data.account_id)
@ -176,7 +179,8 @@ class WorkspaceMembersApi(Resource):
@openapi_ns.doc(params=query_params_from_model(MemberListQuery))
@openapi_ns.response(200, "Member list", openapi_ns.models[MemberListResponse.__name__])
@auth_router.guard_workspace(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@require_workspace_role()
def get(self, workspace_id: str, *, auth_data: AuthData):
try:
query = MemberListQuery.model_validate(request.args.to_dict(flat=True))
@ -198,11 +202,8 @@ class WorkspaceMembersApi(Resource):
@openapi_ns.expect(openapi_ns.models[MemberInvitePayload.__name__])
@openapi_ns.response(201, "Member invited", openapi_ns.models[MemberInviteResponse.__name__])
@auth_router.guard_workspace(
scope=Scope.WORKSPACE_WRITE,
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
)
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
def post(self, workspace_id: str, *, auth_data: AuthData):
payload = _validate_body(MemberInvitePayload)
inviter = _load_account(auth_data.account_id)
@ -252,11 +253,8 @@ class WorkspaceMemberApi(Resource):
"""
@openapi_ns.response(200, "Member removed", openapi_ns.models[MemberActionResponse.__name__])
@auth_router.guard_workspace(
scope=Scope.WORKSPACE_WRITE,
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
)
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
def delete(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
operator = _load_account(auth_data.account_id)
tenant = _load_tenant(workspace_id)
@ -286,11 +284,8 @@ class WorkspaceMemberRoleApi(Resource):
@openapi_ns.expect(openapi_ns.models[MemberRoleUpdatePayload.__name__])
@openapi_ns.response(200, "Role updated", openapi_ns.models[MemberActionResponse.__name__])
@auth_router.guard_workspace(
scope=Scope.WORKSPACE_WRITE,
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
)
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
def put(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
payload = _validate_body(MemberRoleUpdatePayload)
operator = _load_account(auth_data.account_id)

View File

@ -1329,9 +1329,9 @@ class TenantService:
) -> TenantAccountRole | None:
"""Return the caller's role in ``tenant_id``, or ``None`` if not a member.
Backs the openapi auth pipeline's ``load_workspace_role`` prepare step:
``None`` is treated as non-member (the pipeline maps it to 404 — no
cross-tenant ID leak) and an out-of-set role to 403.
Backs ``controllers.openapi.auth.role_gate.require_workspace_role``:
the gate maps ``None`` to 404 (non-member — no cross-tenant ID leak)
and an out-of-set role to 403, so it never touches the ORM itself.
``None``/empty ``account_id`` short-circuits to ``None`` so SSO
bearers (no account) collapse to the non-member path. Mirrors the

View File

@ -16,20 +16,20 @@ def test_auth_router_is_pipeline_router():
assert isinstance(auth_router, PipelineRouter)
def test_account_pipeline_prepare_has_six_entries():
assert len(account_pipeline._prepare) == 6
def test_account_pipeline_prepare_has_four_entries():
assert len(account_pipeline._prepare) == 4
def test_account_auth_list_has_seven_entries():
assert len(account_pipeline._auth) == 7
def test_account_auth_list_has_five_entries():
assert len(account_pipeline._auth) == 5
def test_external_sso_pipeline_prepare_has_four_entries():
assert len(external_sso_pipeline._prepare) == 4
def test_external_sso_auth_list_has_four_entries():
assert len(external_sso_pipeline._auth) == 4
def test_external_sso_auth_list_has_three_entries():
assert len(external_sso_pipeline._auth) == 3
def test_account_pipeline_has_unconditional_load_account():
@ -41,14 +41,17 @@ def test_external_sso_pipeline_all_prepare_entries_are_when():
assert all(isinstance(s, When) for s in external_sso_pipeline._prepare)
def test_account_pipeline_has_one_unconditional_auth_step():
non_when = [s for s in account_pipeline._auth if not isinstance(s, When)]
assert len(non_when) == 1
def test_first_auth_entry_is_check_scope_in_both_pipelines():
assert not isinstance(account_pipeline._auth[0], When)
assert not isinstance(external_sso_pipeline._auth[0], When)
def test_external_sso_pipeline_has_one_unconditional_auth_step():
non_when = [s for s in external_sso_pipeline._auth if not isinstance(s, When)]
assert len(non_when) == 1
def test_remaining_auth_entries_are_when_for_account():
assert all(isinstance(s, When) for s in account_pipeline._auth[1:])
def test_remaining_auth_entries_are_when_for_external_sso():
assert all(isinstance(s, When) for s in external_sso_pipeline._auth[1:])
def test_router_routes_contain_both_token_types():

View File

@ -4,13 +4,11 @@ from controllers.openapi.auth.conditions import (
EDITION_CE,
EDITION_EE,
EDITION_SAAS,
HAS_ALLOWED_ROLES,
LOADED_APP_IS_PRIVATE,
PATH_HAS_APP_ID,
TOKEN_IS_OAUTH_ACCOUNT,
TOKEN_IS_OAUTH_EXTERNAL_SSO,
WEBAPP_AUTH_ENABLED,
WORKSPACE_MEMBERSHIP_REQUIRED,
Cond,
config_cond,
data_cond,
@ -18,15 +16,13 @@ from controllers.openapi.auth.conditions import (
)
from controllers.openapi.auth.data import AuthData, Edition, RequestContext
from libs.oauth_bearer import TokenType
from models.account import TenantAccountRole
from services.enterprise.enterprise_service import WebAppAccessMode
def _ctx(token_type=TokenType.OAUTH_ACCOUNT, path_params=None, **kwargs):
def _ctx(token_type=TokenType.OAUTH_ACCOUNT, path_params=None):
return RequestContext(
token_type=token_type,
path_params=path_params or {},
**kwargs,
)
@ -145,28 +141,3 @@ def test_loaded_app_is_private():
assert LOADED_APP_IS_PRIVATE(_ctx(), data_public) is False
assert LOADED_APP_IS_PRIVATE(_ctx(), data_none) is False
assert LOADED_APP_IS_PRIVATE(_ctx(), None) is False
def test_workspace_membership_required_true():
assert WORKSPACE_MEMBERSHIP_REQUIRED(_ctx(workspace_membership=True)) is True
def test_workspace_membership_required_false():
assert WORKSPACE_MEMBERSHIP_REQUIRED(_ctx(workspace_membership=False)) is False
def test_workspace_membership_required_default():
assert WORKSPACE_MEMBERSHIP_REQUIRED(_ctx()) is False
def test_has_allowed_roles_true():
ctx = _ctx(allowed_roles=frozenset({TenantAccountRole.OWNER}))
assert HAS_ALLOWED_ROLES(ctx) is True
def test_has_allowed_roles_false():
assert HAS_ALLOWED_ROLES(_ctx(allowed_roles=None)) is False
def test_has_allowed_roles_default():
assert HAS_ALLOWED_ROLES(_ctx()) is False

View File

@ -115,69 +115,3 @@ def test_auth_data_token_id_optional():
scopes=frozenset(),
)
assert data.token_id is None
def test_request_context_workspace_membership_default_false():
ctx = RequestContext(token_type=TokenType.OAUTH_ACCOUNT, path_params={})
assert ctx.workspace_membership is False
def test_request_context_workspace_membership_set():
ctx = RequestContext(token_type=TokenType.OAUTH_ACCOUNT, path_params={}, workspace_membership=True)
assert ctx.workspace_membership is True
def test_request_context_allowed_roles_default_none():
ctx = RequestContext(token_type=TokenType.OAUTH_ACCOUNT, path_params={})
assert ctx.allowed_roles is None
def test_request_context_allowed_roles_set():
from models.account import TenantAccountRole
roles = frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN})
ctx = RequestContext(token_type=TokenType.OAUTH_ACCOUNT, path_params={}, allowed_roles=roles)
assert ctx.allowed_roles == roles
def test_auth_data_allowed_roles_default_none():
data = AuthData(
token_type=TokenType.OAUTH_ACCOUNT,
token_hash="abc",
scopes=frozenset(),
)
assert data.allowed_roles is None
def test_auth_data_allowed_roles_set():
from models.account import TenantAccountRole
roles = frozenset({TenantAccountRole.ADMIN})
data = AuthData(
token_type=TokenType.OAUTH_ACCOUNT,
token_hash="abc",
scopes=frozenset(),
allowed_roles=roles,
)
assert data.allowed_roles == roles
def test_auth_data_tenant_role_default_none():
data = AuthData(
token_type=TokenType.OAUTH_ACCOUNT,
token_hash="abc",
scopes=frozenset(),
)
assert data.tenant_role is None
def test_auth_data_tenant_role_set():
from models.account import TenantAccountRole
data = AuthData(
token_type=TokenType.OAUTH_ACCOUNT,
token_hash="abc",
scopes=frozenset(),
tenant_role=TenantAccountRole.ADMIN,
)
assert data.tenant_role == TenantAccountRole.ADMIN

View File

@ -247,60 +247,6 @@ def test_guard_populates_external_identity_from_subject_email(app):
assert received["data"].external_identity.issuer == "https://idp.example.com"
def test_guard_workspace_sets_membership_and_roles(app):
from models.account import TenantAccountRole
router = _make_router()
received = {}
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
with (
patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()),
patch("controllers.openapi.auth.pipeline.reset_auth_ctx"),
):
mock_auth.return_value.authenticate.return_value = _fake_identity()
roles = frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN})
@router.guard_workspace(
scope=Scope.FULL,
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
allowed_roles=roles,
)
def view(*, auth_data):
received["data"] = auth_data
view()
assert isinstance(received["data"], AuthData)
assert received["data"].allowed_roles == roles
def test_guard_workspace_without_roles(app):
router = _make_router()
received = {}
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
with (
patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()),
patch("controllers.openapi.auth.pipeline.reset_auth_ctx"),
):
mock_auth.return_value.authenticate.return_value = _fake_identity()
@router.guard_workspace(scope=Scope.FULL)
def view(*, auth_data):
received["data"] = auth_data
view()
assert isinstance(received["data"], AuthData)
assert received["data"].allowed_roles is None
def test_guard_no_external_identity_when_subject_email_absent(app):
router = _make_router()
received = {}

View File

@ -2,7 +2,6 @@ import uuid
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from controllers.openapi.auth.data import AuthData, ExternalIdentity
@ -11,12 +10,9 @@ from controllers.openapi.auth.prepare import (
load_app,
load_app_access_mode,
load_tenant,
load_tenant_from_request,
load_workspace_role,
resolve_external_user,
)
from libs.oauth_bearer import TokenType
from models.account import TenantAccountRole
def _make_auth_data(**kwargs) -> AuthData:
@ -58,21 +54,14 @@ def test_load_app_raises_not_found_when_not_normal():
load_app(data)
def test_load_app_stashes_app_even_when_api_disabled():
def test_load_app_raises_forbidden_when_api_disabled():
app = MagicMock()
app.status = "normal"
app.enable_api = False
data = _make_auth_data(path_params={"app_id": "abc"})
with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=app):
load_app(data)
assert data.app is app
def test_load_app_skips_when_already_set():
existing_app = MagicMock()
data = _make_auth_data(app=existing_app, path_params={"app_id": "abc"})
load_app(data)
assert data.app is existing_app
with pytest.raises(Forbidden):
load_app(data)
def test_load_tenant_writes_tenant():
@ -86,13 +75,6 @@ def test_load_tenant_writes_tenant():
assert data.tenant is tenant
def test_load_tenant_skips_when_already_set():
existing_tenant = MagicMock()
data = _make_auth_data(app=MagicMock(), tenant=existing_tenant)
load_tenant(data)
assert data.tenant is existing_tenant
def test_load_tenant_raises_forbidden_when_archived():
from models.account import TenantStatus
@ -133,13 +115,6 @@ def test_load_account_writes_caller():
assert data.caller_kind == "account"
def test_load_account_skips_when_already_set():
existing_caller = MagicMock()
data = _make_auth_data(account_id=uuid.uuid4(), caller=existing_caller)
load_account(data)
assert data.caller is existing_caller
def test_load_account_sets_current_tenant_when_tenant_present():
account = MagicMock()
tenant = MagicMock()
@ -206,143 +181,3 @@ def test_load_app_access_mode_no_op_when_app_missing():
data = _make_auth_data()
load_app_access_mode(data)
assert data.app_access_mode is None
@pytest.fixture
def flask_app():
return Flask(__name__)
def test_load_tenant_from_request_from_path_params(flask_app):
tenant = MagicMock()
tenant.status = "normal"
wid = str(uuid.uuid4())
data = _make_auth_data(path_params={"workspace_id": wid})
with flask_app.test_request_context("/test"):
with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=tenant):
load_tenant_from_request(data)
assert data.tenant is tenant
def test_load_tenant_from_request_from_query_param(flask_app):
tenant = MagicMock()
tenant.status = "normal"
wid = str(uuid.uuid4())
data = _make_auth_data(path_params={})
with flask_app.test_request_context(f"/test?workspace_id={wid}"):
with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=tenant):
load_tenant_from_request(data)
assert data.tenant is tenant
def test_load_tenant_from_request_skips_when_already_set(flask_app):
existing_tenant = MagicMock()
data = _make_auth_data(tenant=existing_tenant, path_params={})
with flask_app.test_request_context("/test"):
load_tenant_from_request(data)
assert data.tenant is existing_tenant
def test_load_tenant_from_request_raises_not_found_when_no_id(flask_app):
data = _make_auth_data(path_params={})
with flask_app.test_request_context("/test"):
with pytest.raises(NotFound):
load_tenant_from_request(data)
def test_load_tenant_from_request_raises_not_found_when_missing(flask_app):
wid = str(uuid.uuid4())
data = _make_auth_data(path_params={"workspace_id": wid})
with flask_app.test_request_context("/test"):
with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=None):
with pytest.raises(NotFound):
load_tenant_from_request(data)
def test_load_tenant_from_request_raises_not_found_when_archived(flask_app):
from models.account import TenantStatus
tenant = MagicMock()
tenant.status = TenantStatus.ARCHIVE
wid = str(uuid.uuid4())
data = _make_auth_data(path_params={"workspace_id": wid})
with flask_app.test_request_context("/test"):
with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=tenant):
with pytest.raises(NotFound):
load_tenant_from_request(data)
def test_load_tenant_from_request_raises_not_found_when_invalid_uuid(flask_app):
data = _make_auth_data(path_params={"workspace_id": "not-a-uuid"})
with flask_app.test_request_context("/test"):
with pytest.raises(NotFound):
load_tenant_from_request(data)
# --- load_workspace_role ---
def test_load_workspace_role_stashes_role():
tenant = MagicMock()
tenant.id = uuid.uuid4()
caller = MagicMock()
caller.status = "active"
data = _make_auth_data(account_id=uuid.uuid4(), tenant=tenant, caller=caller)
with patch(
"controllers.openapi.auth.prepare.TenantService.get_account_role_in_tenant",
return_value=TenantAccountRole.ADMIN,
):
load_workspace_role(data)
assert data.tenant_role == TenantAccountRole.ADMIN
def test_load_workspace_role_none_when_not_member():
tenant = MagicMock()
tenant.id = uuid.uuid4()
caller = MagicMock()
caller.status = "active"
data = _make_auth_data(account_id=uuid.uuid4(), tenant=tenant, caller=caller)
with patch(
"controllers.openapi.auth.prepare.TenantService.get_account_role_in_tenant",
return_value=None,
):
load_workspace_role(data)
assert data.tenant_role is None
def test_load_workspace_role_none_when_account_inactive():
tenant = MagicMock()
tenant.id = uuid.uuid4()
caller = MagicMock()
caller.status = "banned"
data = _make_auth_data(account_id=uuid.uuid4(), tenant=tenant, caller=caller)
load_workspace_role(data)
assert data.tenant_role is None
def test_load_workspace_role_skips_when_already_set():
tenant = MagicMock()
tenant.id = uuid.uuid4()
caller = MagicMock()
caller.status = "active"
data = _make_auth_data(
account_id=uuid.uuid4(),
tenant=tenant,
caller=caller,
tenant_role=TenantAccountRole.OWNER,
)
load_workspace_role(data)
assert data.tenant_role == TenantAccountRole.OWNER
def test_load_workspace_role_skips_when_tenant_missing():
data = _make_auth_data(account_id=uuid.uuid4())
load_workspace_role(data)
assert data.tenant_role is None
def test_load_workspace_role_skips_when_account_id_missing():
tenant = MagicMock()
data = _make_auth_data(tenant=tenant, account_id=None)
load_workspace_role(data)
assert data.tenant_role is None

View File

@ -0,0 +1,330 @@
"""Role-gate tests.
The decorator wraps `validate_bearer` + `accept_subjects` and must:
- 404 when caller is not a member of ``workspace_id`` (parity with
`GET /openapi/v1/workspaces/<id>`; prevents tenant-id existence leak)
- 403 when caller IS a member but their role is not in the allowed set
- pass through when role matches (or when no role restriction given)
- raise RuntimeError on missing auth context / account_id / workspace_id —
those are wiring bugs, not user-driven failures
Identity is read from the openapi auth ContextVar — the slot
`validate_bearer` publishes — so these tests seed it via `_seed`
(``set_auth_ctx``), NOT ``flask.g``. `test_seeding_only_flask_g_*`
locks in that ``g`` is *not* a valid identity source.
"""
from __future__ import annotations
import uuid
from contextlib import contextmanager
from datetime import UTC, datetime
from unittest.mock import patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, NotFound
from controllers.openapi.auth.role_gate import require_workspace_role
from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, reset_auth_ctx, set_auth_ctx
from models.account import TenantAccountRole
# Tokens from `_seed`'s `set_auth_ctx` calls, drained after each test so a
# published identity can't leak into the next (the ContextVar is module-global
# and worker threads are reused). Seed via `_seed(...)`, never `flask.g`.
_seed_tokens: list = []
def _seed(ctx: AuthContext) -> None:
_seed_tokens.append(set_auth_ctx(ctx))
@pytest.fixture(autouse=True)
def _reset_auth_ctx():
yield
while _seed_tokens:
reset_auth_ctx(_seed_tokens.pop())
def _account_ctx(account_id: uuid.UUID | None = None) -> AuthContext:
return AuthContext(
subject_type=SubjectType.ACCOUNT,
subject_email="user@example.com",
subject_issuer="dify:account",
account_id=account_id or uuid.uuid4(),
client_id="difyctl",
scopes=frozenset({Scope.FULL}),
token_id=uuid.uuid4(),
token_type=TokenType.OAUTH_ACCOUNT,
expires_at=datetime.now(UTC),
token_hash="h1",
verified_tenants={},
)
def _sso_ctx() -> AuthContext:
return AuthContext(
subject_type=SubjectType.EXTERNAL_SSO,
subject_email="sso@partner.com",
subject_issuer="https://idp.partner.com",
account_id=None,
client_id="difyctl",
scopes=frozenset({Scope.APPS_RUN}),
token_id=uuid.uuid4(),
token_type=TokenType.OAUTH_EXTERNAL_SSO,
expires_at=datetime.now(UTC),
token_hash="h2",
verified_tenants={},
)
@contextmanager
def _stub_role(role: TenantAccountRole | None):
"""Stub the service-layer membership lookup the gate delegates to.
The gate no longer issues SQL itself — it calls
``TenantService.get_account_role_in_tenant`` and acts purely on the
returned role (``None`` → non-member). These tests pin that behaviour;
the query itself is covered in ``TestTenantService``.
"""
with patch(
"controllers.openapi.auth.role_gate.TenantService.get_account_role_in_tenant",
return_value=role,
) as mocked:
yield mocked
# ---------------------------------------------------------------------------
# Non-member → 404
# ---------------------------------------------------------------------------
def test_non_member_gets_404():
app = Flask(__name__)
workspace_id = str(uuid.uuid4())
@require_workspace_role()
def view(workspace_id: str) -> str:
return "ok"
with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/switch"):
_seed(_account_ctx())
with _stub_role(None):
with pytest.raises(NotFound):
view(workspace_id=workspace_id)
# ---------------------------------------------------------------------------
# Member with insufficient role → 403
# ---------------------------------------------------------------------------
def test_normal_member_blocked_when_admin_required():
app = Flask(__name__)
workspace_id = str(uuid.uuid4())
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
def view(workspace_id: str) -> str:
return "ok"
with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/members"):
_seed(_account_ctx())
with _stub_role(TenantAccountRole.NORMAL):
with pytest.raises(Forbidden):
view(workspace_id=workspace_id)
def test_editor_blocked_when_admin_required():
app = Flask(__name__)
workspace_id = str(uuid.uuid4())
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
def view(workspace_id: str) -> str:
return "ok"
with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/members"):
_seed(_account_ctx())
with _stub_role(TenantAccountRole.EDITOR):
with pytest.raises(Forbidden):
view(workspace_id=workspace_id)
# ---------------------------------------------------------------------------
# Member with allowed role → pass
# ---------------------------------------------------------------------------
def test_admin_passes_when_admin_required():
app = Flask(__name__)
workspace_id = str(uuid.uuid4())
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
def view(workspace_id: str) -> str:
return "ok"
with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/members"):
_seed(_account_ctx())
with _stub_role(TenantAccountRole.ADMIN):
assert view(workspace_id=workspace_id) == "ok"
def test_owner_passes_when_admin_required():
app = Flask(__name__)
workspace_id = str(uuid.uuid4())
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
def view(workspace_id: str) -> str:
return "ok"
with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/members"):
_seed(_account_ctx())
with _stub_role(TenantAccountRole.OWNER):
assert view(workspace_id=workspace_id) == "ok"
# ---------------------------------------------------------------------------
# Membership-only (no role restriction)
# ---------------------------------------------------------------------------
def test_membership_only_passes_for_any_role():
app = Flask(__name__)
workspace_id = str(uuid.uuid4())
@require_workspace_role()
def view(workspace_id: str) -> str:
return "ok"
for role in (
TenantAccountRole.OWNER,
TenantAccountRole.ADMIN,
TenantAccountRole.EDITOR,
TenantAccountRole.NORMAL,
TenantAccountRole.DATASET_OPERATOR,
):
with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/switch"):
_seed(_account_ctx())
with _stub_role(role):
assert view(workspace_id=workspace_id) == "ok"
def test_membership_only_still_404s_non_member():
app = Flask(__name__)
workspace_id = str(uuid.uuid4())
@require_workspace_role()
def view(workspace_id: str) -> str:
return "ok"
with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/switch"):
_seed(_account_ctx())
with _stub_role(None):
with pytest.raises(NotFound):
view(workspace_id=workspace_id)
# ---------------------------------------------------------------------------
# Lookup is scoped to the caller's account_id and the URL workspace_id
# ---------------------------------------------------------------------------
def test_lookup_is_scoped_to_caller_and_workspace():
"""The decorator must delegate the lookup keyed on
`(caller's account_id, URL workspace_id)` — otherwise a member of
workspace A could quietly hit endpoints for workspace B. Assert the
exact arguments handed to the service; the SQL those arguments compile
to is pinned in ``TestTenantService.test_get_account_role_in_tenant_*``.
"""
app = Flask(__name__)
account_id = uuid.uuid4()
workspace_id = str(uuid.uuid4())
@require_workspace_role()
def view(workspace_id: str) -> str:
return "ok"
with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/switch"):
_seed(_account_ctx(account_id=account_id))
with _stub_role(TenantAccountRole.NORMAL) as mocked:
view(workspace_id=workspace_id)
_session, passed_account_id, passed_workspace_id = mocked.call_args.args
assert passed_account_id == str(account_id)
assert passed_workspace_id == workspace_id
# ---------------------------------------------------------------------------
# Wiring bugs surface as RuntimeError (loud), not 403 (silent)
# ---------------------------------------------------------------------------
def test_missing_auth_ctx_is_runtime_error():
app = Flask(__name__)
workspace_id = str(uuid.uuid4())
@require_workspace_role()
def view(workspace_id: str) -> str:
return "ok"
with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/switch"):
with pytest.raises(RuntimeError):
view(workspace_id=workspace_id)
def test_seeding_only_flask_g_does_not_satisfy_gate():
"""Regression — pins the identity source to the ContextVar, not ``flask.g``.
Production fills the ContextVar (``validate_bearer`` → ``set_auth_ctx``)
and never touches ``g.auth_ctx``. An earlier revision of this gate read
``g.auth_ctx``, so every real request raised RuntimeError → 500 while the
suite stayed green (it seeded ``g`` directly). Here we seed ONLY ``g`` and
leave the ContextVar empty: the gate must still raise, proving it does not
accept ``g`` as an identity source. Reading ``g`` again would let the
membership lookup run (stubbed to succeed) and this would fail.
"""
from flask import g
app = Flask(__name__)
workspace_id = str(uuid.uuid4())
@require_workspace_role()
def view(workspace_id: str) -> str:
return "ok"
with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/switch"):
g.auth_ctx = _account_ctx() # the wrong slot — must be ignored
with _stub_role(TenantAccountRole.OWNER):
with pytest.raises(RuntimeError):
view(workspace_id=workspace_id)
def test_sso_caller_is_runtime_error():
"""External SSO context has account_id=None — the caller stacked the
role gate without `accept_subjects(SubjectType.ACCOUNT)`. That's a
wiring bug, surface it as RuntimeError rather than 404 the SSO user."""
app = Flask(__name__)
workspace_id = str(uuid.uuid4())
@require_workspace_role()
def view(workspace_id: str) -> str:
return "ok"
with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/switch"):
_seed(_sso_ctx())
with pytest.raises(RuntimeError):
view(workspace_id=workspace_id)
def test_missing_workspace_id_kwarg_is_runtime_error():
app = Flask(__name__)
@require_workspace_role()
def view() -> str:
return "ok"
with app.test_request_context("/openapi/v1/foo"):
_seed(_account_ctx())
with pytest.raises(RuntimeError):
view()

View File

@ -2,22 +2,18 @@ import uuid
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from werkzeug.exceptions import Forbidden, Unauthorized
from controllers.openapi.auth.data import AuthData
from controllers.openapi.auth.verify import (
check_acl,
check_app_access,
check_app_api_enabled,
check_membership,
check_private_app_permission,
check_scope,
check_workspace_mismatch,
check_workspace_role,
)
from libs.oauth_bearer import Scope, TokenType
from models.account import Tenant, TenantAccountRole
from models.account import Tenant
from models.model import App
from services.enterprise.enterprise_service import WebAppAccessMode
@ -144,92 +140,3 @@ def test_check_private_app_permission_passes_when_allowed():
target = "controllers.openapi.auth.verify.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp"
with patch(target, return_value=True):
check_private_app_permission(data)
# --- check_workspace_mismatch ---
@pytest.fixture
def flask_app():
return Flask(__name__)
def test_check_workspace_mismatch_passes_when_tenant_none(flask_app):
with flask_app.test_request_context("/test"):
check_workspace_mismatch(_data(tenant=None))
def test_check_workspace_mismatch_passes_when_ids_match(flask_app):
tenant = MagicMock(spec=Tenant)
tid = uuid.uuid4()
tenant.id = tid
with flask_app.test_request_context(f"/test?workspace_id={tid}"):
check_workspace_mismatch(_data(tenant=tenant, path_params={}))
def test_check_workspace_mismatch_raises_422_on_mismatch(flask_app):
from werkzeug.exceptions import UnprocessableEntity
tenant = MagicMock(spec=Tenant)
tenant.id = uuid.uuid4()
other_id = uuid.uuid4()
with flask_app.test_request_context(f"/test?workspace_id={other_id}"):
with pytest.raises(UnprocessableEntity):
check_workspace_mismatch(_data(tenant=tenant, path_params={}))
def test_check_workspace_mismatch_passes_when_no_request_workspace_id(flask_app):
tenant = MagicMock(spec=Tenant)
tenant.id = uuid.uuid4()
with flask_app.test_request_context("/test"):
check_workspace_mismatch(_data(tenant=tenant, path_params={}))
# --- check_workspace_role ---
def test_check_workspace_role_passes_when_allowed_roles_none():
check_workspace_role(_data(allowed_roles=None))
def test_check_workspace_role_raises_not_found_when_not_member():
data = _data(tenant_role=None, allowed_roles=frozenset({TenantAccountRole.ADMIN}))
with pytest.raises(NotFound):
check_workspace_role(data)
def test_check_workspace_role_raises_forbidden_when_wrong_role():
data = _data(
tenant_role=TenantAccountRole.EDITOR,
allowed_roles=frozenset({TenantAccountRole.OWNER}),
)
with pytest.raises(Forbidden, match="insufficient workspace role"):
check_workspace_role(data)
def test_check_workspace_role_passes_when_role_allowed():
data = _data(
tenant_role=TenantAccountRole.ADMIN,
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
)
check_workspace_role(data)
# --- check_app_api_enabled ---
def test_check_app_api_enabled_passes_when_enabled():
app = MagicMock(spec=App)
app.enable_api = True
check_app_api_enabled(_data(app=app))
def test_check_app_api_enabled_raises_forbidden_when_disabled():
app = MagicMock(spec=App)
app.enable_api = False
with pytest.raises(Forbidden, match="service_api_disabled"):
check_app_api_enabled(_data(app=app))
def test_check_app_api_enabled_passes_when_app_none():
check_app_api_enabled(_data(app=None))

View File

@ -9,18 +9,7 @@ from controllers.openapi.auth.pipeline import PipelineRouter
from libs.oauth_bearer import Scope, TokenType
def _stub_execute(
self,
args,
kwargs,
view,
*,
scope=None,
allowed_token_types=None,
edition=None,
workspace_membership=False,
allowed_roles=None,
):
def _stub_execute(self, args, kwargs, view, *, scope=None, allowed_token_types=None, edition=None):
"""Bypass all auth logic; inject minimal AuthData and call the view directly."""
kwargs["auth_data"] = AuthData(
token_type=TokenType.OAUTH_ACCOUNT,
@ -29,7 +18,6 @@ def _stub_execute(
token_id=uuid.uuid4(),
scopes=frozenset({Scope.FULL}),
required_scope=scope,
allowed_roles=allowed_roles,
)
return view(*args, **kwargs)

View File

@ -9,10 +9,9 @@ Coverage:
Auth-pipeline plumbing is bypassed via the `bypass_pipeline` fixture from
conftest.py; the bearer identity is seeded into the openapi auth ContextVar
via `_seed` (the slot `validate_bearer` publishes). Tests that exercise
endpoint *bodies* skip the single `guard_workspace` decorator via
``__wrapped__`` — membership and role enforcement live in the auth pipeline
and are covered in `auth/test_prepare.py` and `auth/test_verify.py`.
via `_seed` (the slot `validate_bearer` publishes), and the role gate's DB
lookup is mocked. Tests that exercise endpoint *bodies* skip the decorators
via ``__wrapped__`` since those layers are covered in `auth/test_role_gate.py`.
"""
from __future__ import annotations
@ -269,7 +268,7 @@ def test_switch_returns_workspace_detail_with_current_true(app, bypass_pipeline,
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/switch", method="POST"):
_seed(_auth_ctx(account_id=acct_id))
body, status = api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
body, status = api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
assert status == 200
assert body["id"] == ws_id
@ -297,7 +296,7 @@ def test_switch_404s_when_service_raises_account_not_link_tenant(app, bypass_pip
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/switch", method="POST"):
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(NotFound):
api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
# ---------------------------------------------------------------------------
@ -331,7 +330,7 @@ def test_members_list_returns_normalized_rows(app, bypass_pipeline, monkeypatch)
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members"):
_seed(_auth_ctx(account_id=acct_id))
body, status = api.get.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
body, status = api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
assert status == 200
assert body["page"] == 1
@ -373,7 +372,7 @@ def test_members_list_paginates_with_query_params(app, bypass_pipeline, monkeypa
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members?page=2&limit=2"):
_seed(_auth_ctx(account_id=acct_id))
body, status = api.get.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
body, status = api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
assert status == 200
assert body["page"] == 2
@ -396,7 +395,7 @@ def test_members_list_rejects_unknown_query_param(app, bypass_pipeline, monkeypa
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members?pg=2"):
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(BadRequest):
api.get.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
# ---------------------------------------------------------------------------
@ -434,7 +433,7 @@ def test_invite_happy_path_returns_invite_url_and_member_id(app, bypass_pipeline
content_type="application/json",
):
_seed(_auth_ctx(account_id=acct_id))
body, status = api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
body, status = api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
assert status == 201
assert body["result"] == "success"
@ -519,7 +518,7 @@ def test_invite_blocked_by_saas_members_cap(app, bypass_pipeline, monkeypatch):
with _invite_request(app, ws_id, acct_id):
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(Forbidden) as exc_info:
api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
body = exc_info.value.response.json
assert body["code"] == "members.limit_exceeded"
@ -565,7 +564,7 @@ def test_invite_blocked_by_ee_workspace_members_license(app, bypass_pipeline, mo
with _invite_request(app, ws_id, acct_id):
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(Forbidden) as exc_info:
api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
body = exc_info.value.response.json
assert body["code"] == "workspace_members.license_exceeded"
@ -604,7 +603,7 @@ def test_invite_ce_passes_when_both_caps_disabled(app, bypass_pipeline, monkeypa
with _invite_request(app, ws_id, acct_id):
_seed(_auth_ctx(account_id=acct_id))
body, status = api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
body, status = api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
assert status == 201
assert body["email"] == "new@example.com"
@ -633,7 +632,7 @@ def test_invite_400_when_already_in_tenant(app, bypass_pipeline, monkeypatch):
):
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(BadRequest):
api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
# ---------------------------------------------------------------------------
@ -666,7 +665,7 @@ def test_delete_member_happy_path(app, bypass_pipeline, monkeypatch):
method="DELETE",
):
_seed(_auth_ctx(account_id=acct_id))
body, status = api.delete.__wrapped__(
body, status = api.delete.__wrapped__.__wrapped__(
api, workspace_id=ws_id, member_id=member_id, auth_data=_auth_data(acct_id)
)
@ -708,7 +707,7 @@ def test_delete_member_exception_mapping(app, bypass_pipeline, monkeypatch, exc,
):
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(expected):
api.delete.__wrapped__(
api.delete.__wrapped__.__wrapped__(
api,
workspace_id=ws_id,
member_id=member_id,
@ -735,7 +734,7 @@ def test_delete_member_404_when_member_missing(app, bypass_pipeline, monkeypatch
):
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(NotFound):
api.delete.__wrapped__(
api.delete.__wrapped__.__wrapped__(
api,
workspace_id=ws_id,
member_id=member_id,
@ -775,7 +774,9 @@ def test_update_role_happy_path(app, bypass_pipeline, monkeypatch):
content_type="application/json",
):
_seed(_auth_ctx(account_id=acct_id))
body, status = api.put.__wrapped__(api, workspace_id=ws_id, member_id=member_id, auth_data=_auth_data(acct_id))
body, status = api.put.__wrapped__.__wrapped__(
api, workspace_id=ws_id, member_id=member_id, auth_data=_auth_data(acct_id)
)
assert status == 200
assert body == {"result": "success"}
@ -819,7 +820,7 @@ def test_update_role_exception_mapping(app, bypass_pipeline, monkeypatch, exc, e
):
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(expected):
api.put.__wrapped__(
api.put.__wrapped__.__wrapped__(
api,
workspace_id=ws_id,
member_id=member_id,
@ -827,6 +828,44 @@ def test_update_role_exception_mapping(app, bypass_pipeline, monkeypatch, exc, e
)
# ---------------------------------------------------------------------------
# Role gate composition — non-member sees 404 even with valid bearer
# ---------------------------------------------------------------------------
def test_non_member_caller_gets_404_on_switch(app, bypass_pipeline, monkeypatch):
"""End-to-end: caller has valid account bearer but no membership in
the requested workspace. The role gate must short-circuit to 404
before any TenantService method is touched."""
ws_id = str(uuid.uuid4())
acct_id = uuid.uuid4()
api = WorkspaceSwitchApi()
mock_db = MagicMock()
mock_db.session.execute.return_value.scalar_one_or_none.return_value = None
switch_mock = Mock()
monkeypatch.setattr(
sys.modules["controllers.openapi.workspaces"],
"TenantService",
_tenant_service(switch_tenant=switch_mock),
)
monkeypatch.setattr(sys.modules["controllers.openapi.workspaces"], "db", mock_db)
monkeypatch.setattr(sys.modules["controllers.openapi.auth.role_gate"], "db", mock_db)
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/switch", method="POST"):
_seed(_auth_ctx(account_id=acct_id))
# Strip only the bearer + surface-gate wrappers; keep the role gate.
# Decorator stack (innermost → outermost):
# role_gate → accept_subjects → validate_bearer
# `post.__wrapped__` is now the role-gate wrapper directly (auth_router.guard is the only outer wrapper).
gated = api.post.__wrapped__
with pytest.raises(NotFound):
gated(api, workspace_id=ws_id)
switch_mock.assert_not_called()
# ---------------------------------------------------------------------------
# _load_tenant rejects archived tenant
# ---------------------------------------------------------------------------
@ -852,7 +891,7 @@ def test_load_tenant_rejects_archived_workspace(app, bypass_pipeline, monkeypatc
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members"):
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(NotFound):
api.get.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
# ---------------------------------------------------------------------------
@ -886,4 +925,4 @@ def test_invite_400_when_register_error(app, bypass_pipeline, monkeypatch):
):
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(BadRequest):
api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))

View File

@ -638,8 +638,8 @@ class TestTenantService:
callable_func(*args, **kwargs)
# ==================== get_account_role_in_tenant Tests ====================
# Backs the auth pipeline's `load_workspace_role`: None => non-member
# (pipeline maps to 404), otherwise the caller's role (out-of-set role => 403).
# Backs `require_workspace_role`: None => non-member (gate maps to 404),
# otherwise the caller's role (gate maps an out-of-set role to 403).
def test_get_account_role_in_tenant_returns_role_for_member(self):
"""A row in TenantAccountJoin yields the caller's role."""

1
api/uv.lock generated
View File

@ -1300,7 +1300,6 @@ requires-dist = [
{ name = "pydantic-ai-slim", extras = ["anthropic", "google", "openai"], marker = "extra == 'server'", specifier = ">=1.85.1,<2.0.0" },
{ name = "pydantic-settings", marker = "extra == 'server'", specifier = ">=2.12.0,<3.0.0" },
{ name = "redis", marker = "extra == 'server'", specifier = ">=7.4.0,<8.0.0" },
{ name = "shell-session-manager", marker = "extra == 'server'", specifier = "==2.1.1" },
{ name = "typing-extensions", specifier = ">=4.12.2,<5.0.0" },
{ name = "uvicorn", extras = ["standard"], marker = "extra == 'server'", specifier = "==0.46.0" },
]

6
cli/.gitignore vendored
View File

@ -4,4 +4,8 @@ node_modules/
*.tsbuildinfo
.vitest-cache/
docs/specs/
context/
context/
test/**/*.ts.map
test/**/*.js.map
test/**/*.js
test/**/*.d.ts

View File

@ -2,7 +2,7 @@ import type { StubServer } from '@test/fixtures/stub-server'
import { testHttpClient } from '@test/fixtures/http-client'
import { jsonResponder, startStubServer } from '@test/fixtures/stub-server'
import { afterEach, describe, expect, it } from 'vitest'
import { isBaseError } from '@/errors/base'
import { isHttpClientError } from '@/errors/base'
import { AccountSessionsClient } from './account-sessions.js'
const LIST_BODY = { page: 1, limit: 100, total: 0, has_more: false, data: [] }
@ -70,7 +70,7 @@ describe('AccountSessionsClient.revoke', () => {
jsonResponder(404, { error: { code: 'not_found', message: 'session not found' } }, cap))
await expect(makeClient(stub.url).revoke('missing')).rejects.toSatisfy(
err => isBaseError(err) && err.httpStatus === 404,
err => isHttpClientError(err) && err.httpStatus === 404,
)
})

View File

@ -2,7 +2,7 @@ import type { StubServer } from '@test/fixtures/stub-server'
import { testHttpClient } from '@test/fixtures/http-client'
import { jsonResponder, startStubServer } from '@test/fixtures/stub-server'
import { afterEach, describe, expect, it } from 'vitest'
import { isBaseError } from '@/errors/base'
import { isHttpClientError } from '@/errors/base'
import { AccountClient } from './account.js'
function makeClient(host: string): AccountClient {
@ -35,7 +35,7 @@ describe('AccountClient.get', () => {
stub = await startStubServer(cap => jsonResponder(401, { error: 'expired' }, cap))
await expect(makeClient(stub.url).get()).rejects.toSatisfy(
err => isBaseError(err) && err.httpStatus === 401,
err => isHttpClientError(err) && err.httpStatus === 401,
)
})
})

View File

@ -2,7 +2,7 @@ import type { StubServer } from '@test/fixtures/stub-server'
import { testHttpClient } from '@test/fixtures/http-client'
import { jsonResponder, startStubServer } from '@test/fixtures/stub-server'
import { afterEach, describe, expect, it } from 'vitest'
import { isBaseError } from '@/errors/base'
import { isHttpClientError } from '@/errors/base'
import { AppsClient } from './apps.js'
const LIST_BODY = { page: 1, limit: 20, total: 0, has_more: false, data: [] }
@ -74,7 +74,7 @@ describe('AppsClient.list', () => {
stub = await startStubServer(cap => jsonResponder(403, { error: 'forbidden' }, cap))
await expect(makeClient(stub.url).list({ workspaceId: 'ws-1' })).rejects.toSatisfy(
err => isBaseError(err) && err.httpStatus === 403,
err => isHttpClientError(err) && err.httpStatus === 403,
)
})
})

View File

@ -5,7 +5,7 @@ import { join } from 'node:path'
import { testHttpClient } from '@test/fixtures/http-client'
import { jsonResponder, startStubServer } from '@test/fixtures/stub-server'
import { afterEach, beforeEach, describe, expect, it } from 'vitest'
import { isBaseError } from '@/errors/base'
import { isHttpClientError } from '@/errors/base'
import { FileUploadClient } from './file-upload.js'
const UPLOADED = {
@ -70,7 +70,7 @@ describe('FileUploadClient.upload', () => {
stub = await startStubServer(cap => jsonResponder(413, { error: 'file too large' }, cap))
await expect(makeClient(stub.url).upload('app-1', filePath)).rejects.toSatisfy(
err => isBaseError(err) && err.httpStatus === 413,
err => isHttpClientError(err) && err.httpStatus === 413,
)
})
})

View File

@ -2,7 +2,7 @@ import type { StubServer } from '@test/fixtures/stub-server'
import { testHttpClient } from '@test/fixtures/http-client'
import { jsonResponder, startStubServer } from '@test/fixtures/stub-server'
import { afterEach, describe, expect, it } from 'vitest'
import { isBaseError } from '@/errors/base'
import { isHttpClientError } from '@/errors/base'
import { MembersClient } from './members.js'
import { WorkspacesClient } from './workspaces.js'
@ -62,7 +62,7 @@ describe('MembersClient.list', () => {
stub = await startStubServer(cap => jsonResponder(403, { error: 'forbidden' }, cap))
await expect(makeClient(stub.url).list('ws-1')).rejects.toSatisfy(
err => isBaseError(err) && err.httpStatus === 403,
err => isHttpClientError(err) && err.httpStatus === 403,
)
})
@ -70,7 +70,7 @@ describe('MembersClient.list', () => {
stub = await startStubServer(cap => jsonResponder(404, { error: 'not found' }, cap))
await expect(makeClient(stub.url).list('ws-missing')).rejects.toSatisfy(
err => isBaseError(err) && err.httpStatus === 404,
err => isHttpClientError(err) && err.httpStatus === 404,
)
})
})
@ -117,7 +117,7 @@ describe('MembersClient.invite', () => {
await expect(
makeClient(stub.url).invite('ws-1', { email: 'u@e.com', role: 'normal' }),
).rejects.toSatisfy(err => isBaseError(err) && err.httpStatus === 400)
).rejects.toSatisfy(err => isHttpClientError(err) && err.httpStatus === 400)
})
})
@ -142,7 +142,7 @@ describe('MembersClient.remove', () => {
stub = await startStubServer(cap => jsonResponder(400, { error: 'cannot operate self' }, cap))
await expect(makeClient(stub.url).remove('ws-1', 'm-1')).rejects.toSatisfy(
err => isBaseError(err) && err.httpStatus === 400,
err => isHttpClientError(err) && err.httpStatus === 400,
)
})
})
@ -170,7 +170,7 @@ describe('MembersClient.updateRole', () => {
await expect(
makeClient(stub.url).updateRole('ws-1', 'm-1', { role: 'admin' }),
).rejects.toSatisfy(err => isBaseError(err) && err.httpStatus === 400)
).rejects.toSatisfy(err => isHttpClientError(err) && err.httpStatus === 400)
})
})
@ -209,7 +209,7 @@ describe('WorkspacesClient.switch (integration with stub)', () => {
const client = new WorkspacesClient(testHttpClient(stub.url, 'dfoa_test'))
await expect(client.switch('ws-x')).rejects.toSatisfy(
err => isBaseError(err) && err.httpStatus === 404,
err => isHttpClientError(err) && err.httpStatus === 404,
)
})
})

View File

@ -1,5 +1,5 @@
import type { HttpClient } from '@/http/types'
import { BaseError } from '@/errors/base'
import { BaseError, HttpClientError } from '@/errors/base'
import { ErrorCode } from '@/errors/codes'
export const DEFAULT_CLIENT_ID = 'difyctl'
@ -80,7 +80,7 @@ export class DeviceFlowApi {
if (res.status === 404)
throw versionSkew()
if (!res.ok) {
throw new BaseError({
throw new HttpClientError({
code: ErrorCode.Server4xxOther,
message: `device/code: HTTP ${res.status}`,
httpStatus: res.status,
@ -133,8 +133,8 @@ export class DeviceFlowApi {
}
}
function versionSkew(): BaseError {
return new BaseError({
function versionSkew(): HttpClientError {
return new HttpClientError({
code: ErrorCode.UnsupportedEndpoint,
message: 'this Dify host does not implement the OAuth device flow',
httpStatus: 404,

View File

@ -2,7 +2,7 @@ import type { StubServer } from '@test/fixtures/stub-server'
import { testHttpClient } from '@test/fixtures/http-client'
import { jsonResponder, startStubServer } from '@test/fixtures/stub-server'
import { afterEach, describe, expect, it } from 'vitest'
import { isBaseError } from '@/errors/base'
import { isHttpClientError } from '@/errors/base'
import { WorkspacesClient } from './workspaces.js'
// WorkspacesClient.switch is covered in members.test.ts; this file covers list().
@ -30,14 +30,14 @@ describe('WorkspacesClient.list', () => {
expect(stub.captured.method).toBe('GET')
expect(stub.captured.url).toBe('/openapi/v1/workspaces')
expect(res.workspaces[0].id).toBe('ws-1')
expect(res.workspaces[0]?.id).toBe('ws-1')
})
it('maps 401 to a classified BaseError', async () => {
stub = await startStubServer(cap => jsonResponder(401, { error: 'expired' }, cap))
await expect(makeClient(stub.url).list()).rejects.toSatisfy(
err => isBaseError(err) && err.httpStatus === 401,
err => isHttpClientError(err) && err.httpStatus === 401,
)
})
})

View File

@ -357,14 +357,14 @@ describe('runApp', () => {
// warm cache with successful run
await runApp(
{ appId: 'app-1', message: 'hi' },
{ bundle: bundle(), http: testHttpClient(mock.url, 'dfoa_test'), host: mock.url, io, cache },
{ active: active(), http: testHttpClient(mock.url, 'dfoa_test'), host: mock.url, io, cache },
)
expect(cache.get(mock.url, 'app-1')).toBeDefined()
mock.setScenario('run-422-stale')
const err = await runApp(
{ appId: 'app-1', message: 'hi' },
{ bundle: bundle(), http: testHttpClient(mock.url, { bearer: 'dfoa_test', retryAttempts: 0 }), host: mock.url, io, cache },
{ active: active(), http: testHttpClient(mock.url, { bearer: 'dfoa_test', retryAttempts: 0 }), host: mock.url, io, cache },
).catch((e: unknown) => e)
expect(err).toMatchObject({ code: 'server_4xx_other', httpStatus: 422 })
expect((err as { hint?: string }).hint).toMatch(/cache cleared/)

View File

@ -7,7 +7,7 @@ import { AppRunClient } from '@/api/app-run'
import { AppsClient } from '@/api/apps'
import { FileUploadClient } from '@/api/file-upload'
import { pickStrategy } from '@/commands/run/app/_strategies/index'
import { BaseError } from '@/errors/base'
import { BaseError, HttpClientError } from '@/errors/base'
import { ErrorCode } from '@/errors/codes'
import { getEnv, processExit } from '@/sys/index'
import { FieldInfo } from '@/types/app-meta'
@ -87,7 +87,7 @@ export async function runApp(opts: RunAppOptions, deps: RunAppDeps): Promise<voi
await executeRun(opts, deps, meta, wsId)
}
catch (err) {
if (err instanceof BaseError && err.httpStatus === 422) {
if (err instanceof HttpClientError && err.httpStatus === 422) {
await meta.invalidate(opts.appId)
throw err.withHint('app metadata cache cleared — if the app was recently republished, run the command again')
}

View File

@ -1,3 +1,4 @@
import type { HttpClientError } from '@/errors/base'
import type { SseEvent } from '@/http/sse'
import { describe, expect, it } from 'vitest'
import { collect, collectorFor, decodeStreamError, HitlPauseError } from './sse-collector'
@ -130,7 +131,7 @@ describe('decodeStreamError', () => {
const err = decodeStreamError(enc.encode(JSON.stringify(env)))
expect(err.message).toBe(inner.args.description)
expect(err.code).toBe('server_4xx_other')
expect(err.httpStatus).toBe(400)
expect((err as HttpClientError).httpStatus).toBe(400)
})
it('unwraps openapi-v1 invoke-error: falls back to inner.message when no args.description', () => {

View File

@ -1,6 +1,6 @@
import type { BaseError } from '@/errors/base'
import type { SseEvent } from '@/http/sse'
import { newError } from '@/errors/base'
import { HttpClientError, newError } from '@/errors/base'
import { ErrorCode } from '@/errors/codes'
import { RUN_MODES } from './handlers'
@ -173,9 +173,9 @@ export function decodeStreamError(data: Uint8Array): BaseError {
const code = env.status !== undefined && env.status > 0 && env.status < 500
? ErrorCode.Server4xxOther
: ErrorCode.Server5xx
let err = newError(code, message)
const err = newError(code, message)
if (env.status !== undefined && env.status > 0)
err = err.withHttpStatus(env.status)
return HttpClientError.from(err).withHttpStatus(env.status)
return err
}

View File

@ -1,10 +1,10 @@
import { describe, expect, it } from 'vitest'
import { BaseError, isBaseError, newError, unknownError } from './base'
import { BaseError, HttpClientError, isBaseError, newError, unknownError } from './base'
import { ErrorCode, ExitCode } from './codes'
describe('BaseError', () => {
it('captures code, message, optional fields', () => {
const err = new BaseError({
const err = new HttpClientError({
code: ErrorCode.AuthExpired,
message: 'session expired',
hint: 'run difyctl auth login',
@ -30,7 +30,6 @@ describe('BaseError', () => {
expect(newError(ErrorCode.AuthExpired, 'x').exit()).toBe(ExitCode.Auth)
expect(newError(ErrorCode.UsageInvalidFlag, 'x').exit()).toBe(ExitCode.Usage)
expect(newError(ErrorCode.VersionSkew, 'x').exit()).toBe(ExitCode.VersionCompat)
expect(newError(ErrorCode.NetworkDns, 'x').exit()).toBe(ExitCode.Generic)
})
it('toString without hint formats "<code>: <message>"', () => {
@ -56,7 +55,7 @@ describe('BaseError', () => {
it('withHttpStatus + withRequest + wrap chain immutably', () => {
const cause = new Error('underlying')
const built = newError(ErrorCode.NetworkTimeout, 'timed out')
const built = HttpClientError.from(newError(ErrorCode.NetworkConnection, 'timed out'))
.withHttpStatus(504)
.withRequest('POST', 'https://x/y')
.wrap(cause)
@ -68,7 +67,7 @@ describe('BaseError', () => {
it('wrap exposes cause via standard Error.cause property', () => {
const cause = new Error('underlying failure')
const wrapped = newError(ErrorCode.NetworkTimeout, 'timed out').wrap(cause)
const wrapped = newError(ErrorCode.NetworkConnection, 'timed out').wrap(cause)
expect(wrapped.cause).toBe(cause)
})
@ -86,3 +85,56 @@ describe('BaseError', () => {
expect(err.cause).toBe(cause)
})
})
describe('error envelope', () => {
it('emits required fields only when minimal', () => {
const err = newError(ErrorCode.Unknown, 'boom')
expect(err.toEnvelope()).toEqual({
error: { code: 'unknown', message: 'boom' },
})
})
it('includes hint / http_status / method / url when present', () => {
const err = HttpClientError.from(newError(ErrorCode.NetworkConnection, 'timed out'))
.withHint('check your network')
.withHttpStatus(504)
.withRequest('POST', 'https://api.dify.ai/v1/x')
expect(err.toEnvelope()).toEqual({
error: {
code: 'network_connection',
message: 'timed out',
hint: 'check your network',
http_status: 504,
method: 'POST',
url: 'https://api.dify.ai/v1/x',
},
})
})
it('renderEnvelope returns a single-line JSON string', () => {
const err = newError(ErrorCode.AuthExpired, 'session expired')
.withHint('run difyctl auth login')
const out = JSON.stringify(err.toEnvelope())
expect(out).toBe(
'{"error":{"code":"auth_expired","message":"session expired","hint":"run difyctl auth login"}}',
)
expect(out).not.toContain('\n')
})
it('renderEnvelope output round-trips through JSON.parse to an ErrorEnvelope shape', () => {
const err = newError(ErrorCode.UsageInvalidFlag, 'bad flag').withHint('see --help')
const parsed = JSON.parse(JSON.stringify(err.toEnvelope()))
expect(parsed).toEqual({
error: { code: 'usage_invalid_flag', message: 'bad flag', hint: 'see --help' },
})
})
it('omits undefined optional fields entirely (no `hint: null`)', () => {
const err = newError(ErrorCode.Server5xx, 'upstream broke')
const envelope = err.toEnvelope()
expect(envelope.error).not.toHaveProperty('hint')
expect(envelope.error).not.toHaveProperty('http_status')
expect(envelope.error).not.toHaveProperty('method')
expect(envelope.error).not.toHaveProperty('url')
})
})

View File

@ -1,31 +1,24 @@
import type { ErrorCodeValue, ExitCodeValue } from './codes'
import type { ErrorEnvelope, PrintableError } from './format'
import { ErrorCode, exitFor } from './codes'
export type BaseErrorOptions = {
readonly code: ErrorCodeValue
readonly message: string
readonly hint?: string
readonly httpStatus?: number
readonly method?: string
readonly url?: string
readonly cause?: unknown
}
export class BaseError extends Error {
export class BaseError extends Error implements PrintableError {
readonly code: ErrorCodeValue
readonly hint?: string
readonly httpStatus?: number
readonly method?: string
readonly url?: string
constructor(opts: BaseErrorOptions) {
super(opts.message, opts.cause === undefined ? undefined : { cause: opts.cause })
this.name = 'BaseError'
this.code = opts.code
this.hint = opts.hint
this.httpStatus = opts.httpStatus
this.method = opts.method
this.url = opts.url
Object.setPrototypeOf(this, new.target.prototype)
}
@ -39,30 +32,31 @@ export class BaseError extends Error {
: `${this.code}: ${this.message}`
}
withHint(hint: string): BaseError {
return new BaseError({ ...this.snapshot(), hint })
toEnvelope(): ErrorEnvelope {
const payload: ErrorEnvelope['error'] = {
code: this.code,
message: this.message,
}
if (this.hint !== undefined)
payload.hint = this.hint
return { error: payload }
}
withHttpStatus(httpStatus: number): BaseError {
return new BaseError({ ...this.snapshot(), httpStatus })
withHint<T extends BaseError>(this: T, hint: string): T {
const Ctor = this.constructor as new (opts: BaseErrorOptions) => T
return new Ctor({ ...this.snapshot(), hint })
}
withRequest(method: string, url: string): BaseError {
return new BaseError({ ...this.snapshot(), method, url })
wrap<T extends BaseError>(this: T, cause: unknown): T {
const Ctor = this.constructor as new (opts: BaseErrorOptions) => T
return new Ctor({ ...this.snapshot(), cause })
}
wrap(cause: unknown): BaseError {
return new BaseError({ ...this.snapshot(), cause })
}
private snapshot(): BaseErrorOptions {
protected snapshot(): BaseErrorOptions {
return {
code: this.code,
message: this.message,
hint: this.hint,
httpStatus: this.httpStatus,
method: this.method,
url: this.url,
cause: this.cause,
}
}
@ -76,6 +70,79 @@ export function isBaseError(value: unknown): value is BaseError {
return value instanceof BaseError
}
export function isHttpClientError(value: unknown): value is HttpClientError {
return value instanceof HttpClientError
}
export function unknownError(message: string, cause?: unknown): BaseError {
return new BaseError({ code: ErrorCode.Unknown, message, cause })
}
type HttpClientErrorOptions = BaseErrorOptions & {
readonly httpStatus?: number
readonly method?: string
readonly url?: string
readonly rawResponse?: string
}
export class HttpClientError extends BaseError {
readonly httpStatus?: number
readonly method?: string
readonly url?: string
readonly rawResponse?: string
constructor(opts: HttpClientErrorOptions) {
super(opts)
this.httpStatus = opts.httpStatus
this.method = opts.method
this.url = opts.url
this.rawResponse = opts.rawResponse
}
override toEnvelope(): ErrorEnvelope {
const envelope = super.toEnvelope()
if (this.httpStatus !== undefined)
envelope.error.http_status = this.httpStatus
if (this.method !== undefined)
envelope.error.method = this.method
if (this.url !== undefined)
envelope.error.url = this.url
if (this.rawResponse !== undefined)
envelope.error.raw_response = this.rawResponse
return envelope
}
protected override snapshot(): HttpClientErrorOptions {
return {
...super.snapshot(),
httpStatus: this.httpStatus,
method: this.method,
url: this.url,
rawResponse: this.rawResponse,
}
}
public static from(error: BaseError): HttpClientError {
return new HttpClientError({
code: error.code,
message: error.message,
hint: error.hint,
cause: error.cause,
})
}
withHttpStatus(httpStatus: number): HttpClientError {
return new HttpClientError({ ...this.snapshot(), httpStatus })
}
withRequest(method: string, url: string): HttpClientError {
return new HttpClientError({ ...this.snapshot(), method, url })
}
withRawResponse(rawResponse: string): HttpClientError {
if (!rawResponse) {
return this
}
return new HttpClientError({ ...this.snapshot(), rawResponse })
}
}

View File

@ -42,8 +42,6 @@ describe('error codes', () => {
[ErrorCode.UsageMissingArg, ExitCode.Usage],
[ErrorCode.ConfigInvalidKey, ExitCode.Usage],
[ErrorCode.ConfigInvalidValue, ExitCode.Usage],
[ErrorCode.NetworkTimeout, ExitCode.Generic],
[ErrorCode.NetworkDns, ExitCode.Generic],
[ErrorCode.Server5xx, ExitCode.Generic],
[ErrorCode.Server4xxOther, ExitCode.Generic],
[ErrorCode.ClientError, ExitCode.Generic],

View File

@ -11,8 +11,7 @@ export const ErrorCode = {
UsageMissingArg: 'usage_missing_arg',
ConfigInvalidKey: 'config_invalid_key',
ConfigInvalidValue: 'config_invalid_value',
NetworkTimeout: 'network_timeout',
NetworkDns: 'network_dns',
NetworkConnection: 'network_connection',
Server5xx: 'server_5xx',
Server4xxOther: 'server_4xx_other',
ClientError: 'client_error',
@ -45,8 +44,7 @@ const CODE_TO_EXIT: Readonly<Record<ErrorCodeValue, ExitCodeValue>> = {
usage_missing_arg: ExitCode.Usage,
config_invalid_key: ExitCode.Usage,
config_invalid_value: ExitCode.Usage,
network_timeout: ExitCode.Generic,
network_dns: ExitCode.Generic,
network_connection: ExitCode.Generic,
server_5xx: ExitCode.Generic,
server_4xx_other: ExitCode.Generic,
client_error: ExitCode.Generic,

View File

@ -1,57 +0,0 @@
import { describe, expect, it } from 'vitest'
import { newError } from './base'
import { ErrorCode } from './codes'
import { renderEnvelope, toEnvelope } from './envelope'
describe('error envelope', () => {
it('emits required fields only when minimal', () => {
const err = newError(ErrorCode.Unknown, 'boom')
expect(toEnvelope(err)).toEqual({
error: { code: 'unknown', message: 'boom' },
})
})
it('includes hint / http_status / method / url when present', () => {
const err = newError(ErrorCode.NetworkTimeout, 'timed out')
.withHint('check your network')
.withHttpStatus(504)
.withRequest('POST', 'https://api.dify.ai/v1/x')
expect(toEnvelope(err)).toEqual({
error: {
code: 'network_timeout',
message: 'timed out',
hint: 'check your network',
http_status: 504,
method: 'POST',
url: 'https://api.dify.ai/v1/x',
},
})
})
it('renderEnvelope returns a single-line JSON string', () => {
const err = newError(ErrorCode.AuthExpired, 'session expired')
.withHint('run difyctl auth login')
const out = renderEnvelope(err)
expect(out).toBe(
'{"error":{"code":"auth_expired","message":"session expired","hint":"run difyctl auth login"}}',
)
expect(out).not.toContain('\n')
})
it('renderEnvelope output round-trips through JSON.parse to an ErrorEnvelope shape', () => {
const err = newError(ErrorCode.UsageInvalidFlag, 'bad flag').withHint('see --help')
const parsed = JSON.parse(renderEnvelope(err))
expect(parsed).toEqual({
error: { code: 'usage_invalid_flag', message: 'bad flag', hint: 'see --help' },
})
})
it('omits undefined optional fields entirely (no `hint: null`)', () => {
const err = newError(ErrorCode.Server5xx, 'upstream broke')
const envelope = toEnvelope(err)
expect(envelope.error).not.toHaveProperty('hint')
expect(envelope.error).not.toHaveProperty('http_status')
expect(envelope.error).not.toHaveProperty('method')
expect(envelope.error).not.toHaveProperty('url')
})
})

View File

@ -1,32 +0,0 @@
import type { BaseError } from './base'
export type ErrorEnvelope = {
error: {
code: string
message: string
hint?: string
http_status?: number
method?: string
url?: string
}
}
export function toEnvelope(err: BaseError): ErrorEnvelope {
const payload: ErrorEnvelope['error'] = {
code: err.code,
message: err.message,
}
if (err.hint !== undefined)
payload.hint = err.hint
if (err.httpStatus !== undefined)
payload.http_status = err.httpStatus
if (err.method !== undefined)
payload.method = err.method
if (err.url !== undefined)
payload.url = err.url
return { error: payload }
}
export function renderEnvelope(err: BaseError): string {
return JSON.stringify(toEnvelope(err))
}

View File

@ -1,26 +1,58 @@
import type { BaseError } from './base'
import { isVerbose } from '@/framework/context'
import { redactBearer } from '@/http/sanitize'
import { colorEnabled, colorScheme } from '@/sys/io/color'
import { renderEnvelope } from './envelope'
export type FormatErrorOptions = {
readonly format?: string
readonly isErrTTY?: boolean
}
export function formatErrorForCli(err: BaseError, opts: FormatErrorOptions = {}): string {
if (opts.format === 'json')
return renderEnvelope(err)
return humanError(err, opts.isErrTTY ?? false)
export type ErrorEnvelope = {
error: {
code: string
message: string
hint?: string
http_status?: number
method?: string
url?: string
raw_response?: string
}
}
function humanError(err: BaseError, isErrTTY: boolean): string {
export type PrintableError = {
toEnvelope: () => ErrorEnvelope
}
export function formatErrorForCli(err: PrintableError, opts: FormatErrorOptions = {}): string {
const env = err.toEnvelope()
if (opts.format === 'json')
return renderEnvelope(env)
return renderHuman(env, opts.isErrTTY ?? false)
}
function renderEnvelope(env: ErrorEnvelope): string {
const raw = env.error.raw_response
if (raw === undefined)
return JSON.stringify(env)
if (!isVerbose()) {
delete env.error.raw_response
return JSON.stringify(env)
}
env.error.raw_response = redactBearer(raw)
return JSON.stringify(env)
}
function renderHuman(env: ErrorEnvelope, isErrTTY: boolean): string {
const cs = colorScheme(colorEnabled(isErrTTY))
const lines: string[] = [`${err.code}: ${err.message}`]
if (err.hint !== undefined)
lines.push(`${cs.magenta('hint:')} ${cs.cyan(err.hint)}`)
if (err.method !== undefined && err.url !== undefined)
lines.push(`request: ${err.method} ${err.url}`)
if (err.httpStatus !== undefined)
lines.push(`http_status: ${err.httpStatus}`)
const e = env.error
const lines: string[] = [`${e.code}: ${e.message}`]
if (e.hint !== undefined)
lines.push(`${cs.magenta('hint:')} ${cs.cyan(e.hint)}`)
if (e.method !== undefined && e.url !== undefined)
lines.push(`request: ${e.method} ${e.url}`)
if (e.http_status !== undefined)
lines.push(`http_status: ${e.http_status}`)
if (isVerbose() && e.raw_response)
lines.push(`raw_response: ${redactBearer(e.raw_response)}`)
return lines.join('\n')
}

View File

@ -1,6 +1,7 @@
import type { CommandOutput } from './output'
import type { ArgDefinition, FlagDefinition, ICommand, InferArgs, InferFlags, OptionalArgValueType } from './types'
import { parseArgv } from './flags'
import { setVerbose } from './context'
import { hasBooleanFlag, parseArgv, VERBOSE_CHAR, VERBOSE_FLAG } from './flags'
export type CommandConstructor = {
new(): Command
@ -28,11 +29,16 @@ type ParseResult<C extends CommandConstructor> = {
export abstract class Command implements ICommand {
static description?: string
static flags: Record<string, FlagDefinition<OptionalArgValueType>> = {}
static args: Record<string, ArgDefinition<string | undefined>> = {}
static examples: string[] = []
abstract run(argv: string[]): Promise<CommandOutput | void>
processGlobalFlags(argv: readonly string[]): void {
setVerbose(hasBooleanFlag(argv, VERBOSE_FLAG, VERBOSE_CHAR))
}
protected parse<C extends CommandConstructor>(ctor: C, argv: string[]): ParseResult<C> {
const meta = {
flags: ctor.flags ?? {},

View File

@ -0,0 +1,15 @@
type CommandContext = {
verbose: boolean
}
const commandContext: CommandContext = {
verbose: false,
}
export function setVerbose(verbose: boolean): void {
commandContext.verbose = verbose
}
export function isVerbose(): boolean {
return commandContext.verbose
}

View File

@ -1,6 +1,24 @@
import type { ArgDefinition, CommandMeta, FlagDefinition, ParsedArgs, ParsedFlags } from './types'
import { UnsupportedArgValueError } from './errors'
export const VERBOSE_FLAG = 'verbose'
export const VERBOSE_CHAR = 'v'
export const Flags = {
string: stringFlag,
stringArray: stringRepeatedFlag,
boolean: booleanFlag,
integer: integerFlag,
outputFormat: outputFormatFlag,
}
const GLOBAL_FLAGS: Record<string, FlagDefinition> = {
[VERBOSE_FLAG]: Flags.boolean({
char: VERBOSE_CHAR,
description: 'enable verbose output',
}),
}
function stringFlag<const Opts extends {
description: string
char?: string
@ -48,14 +66,6 @@ function integerFlag<const Opts extends { description: string, char?: string, de
return { type: 'integer', ...opts } as FlagDefinition<Opts extends { default: number } ? number : number | undefined>
}
export const Flags = {
string: stringFlag,
stringArray: stringRepeatedFlag,
boolean: booleanFlag,
integer: integerFlag,
outputFormat: outputFormatFlag,
}
function stringArg<const Opts extends { description: string, required?: boolean }>(
opts: Opts,
): ArgDefinition<Opts extends { required: true } ? string : string | undefined> {
@ -99,8 +109,8 @@ function accumulateFlagValue(flags: ParsedFlags, name: string, value: string | b
}
}
function resolveByChar(char: string, meta: CommandMeta): [name: string, def: FlagDefinition] | undefined {
for (const [name, def] of Object.entries(meta.flags)) {
function resolveByChar(char: string, flags: Record<string, FlagDefinition>): [name: string, def: FlagDefinition] | undefined {
for (const [name, def] of Object.entries(flags)) {
if (def.char === char)
return [name, def]
}
@ -115,12 +125,12 @@ function validateFlagOptions(name: string, raw: string, def: FlagDefinition): vo
type ResolvedFlag = { name: string, def: FlagDefinition, label: string, inlineRaw: string | undefined }
function resolveToken(token: string, meta: CommandMeta): ResolvedFlag | null {
function resolveToken(token: string, flags: Record<string, FlagDefinition>): ResolvedFlag | null {
if (token.startsWith('--')) {
const eqIdx = token.indexOf('=')
const name = eqIdx !== -1 ? token.slice(2, eqIdx) : token.slice(2)
const inlineRaw = eqIdx !== -1 ? token.slice(eqIdx + 1) : undefined
const def = meta.flags[name]
const def = flags[name]
if (!def)
throw new Error(`unknown flag: --${name}`)
return { name, def, label: `--${name}`, inlineRaw }
@ -128,7 +138,7 @@ function resolveToken(token: string, meta: CommandMeta): ResolvedFlag | null {
if (token.length === 2 && token[1] !== undefined) {
const char = token[1]
const resolved = resolveByChar(char, meta)
const resolved = resolveByChar(char, flags)
if (!resolved)
throw new Error(`unknown flag: -${char}`)
const [name, def] = resolved
@ -138,6 +148,21 @@ function resolveToken(token: string, meta: CommandMeta): ResolvedFlag | null {
return null
}
// Scans argv for a boolean flag without throwing on unknown tokens, so it is safe
// to call before the command-specific flag set is known (e.g. global flags).
export function hasBooleanFlag(argv: readonly string[], name: string, char?: string): boolean {
for (const token of argv) {
if (token === '--')
break
if (token === `--${name}` || token === `--${name}=true` || token === `--${name}=1`)
return true
if (char !== undefined && token === `-${char}`)
return true
}
return false
}
export function parseArgv(argv: readonly string[], meta: CommandMeta): { args: ParsedArgs, flags: ParsedFlags } {
const flags: ParsedFlags = {}
const positional: string[] = []
@ -159,7 +184,10 @@ export function parseArgv(argv: readonly string[], meta: CommandMeta): { args: P
continue
}
const resolved = resolveToken(token, meta)
const resolved = resolveToken(token, {
...meta.flags,
...GLOBAL_FLAGS, // pass global flags to prevent unknown flag error
})
if (!resolved) {
positional.push(token)
continue

View File

@ -1,7 +1,7 @@
import type { CommandConstructor } from './command'
import type { CommandTree } from './registry'
import { describe, expect, it } from 'vitest'
import { BaseError, newError } from '@/errors/base'
import { BaseError, HttpClientError, newError } from '@/errors/base'
import { ErrorCode, ExitCode } from '@/errors/codes'
import { Command } from './command'
import { run, sniffOutputFormat } from './run'
@ -171,7 +171,7 @@ describe('run() catch routing', () => {
it('routes Server5xx error with http_status line and generic exit', async () => {
class Throwing extends Command {
async run(_argv: string[]) {
throw newError(ErrorCode.Server5xx, 'upstream boom').withHttpStatus(502)
throw HttpClientError.from(newError(ErrorCode.Server5xx, 'upstream boom')).withHttpStatus(502)
}
}
const result = await captureRun(makeTree(Throwing), ['cmd'])
@ -182,7 +182,7 @@ describe('run() catch routing', () => {
it('renders request line and http_status when both are present', async () => {
class Throwing extends Command {
async run(_argv: string[]) {
throw newError(ErrorCode.Server5xx, 'upstream boom')
throw HttpClientError.from(newError(ErrorCode.Server5xx, 'upstream boom'))
.withRequest('GET', 'https://api.dify.ai/v1/me')
.withHttpStatus(502)
}
@ -197,7 +197,7 @@ describe('run() catch routing', () => {
it('serializes method and url in JSON envelope', async () => {
class Throwing extends Command {
async run(_argv: string[]) {
throw newError(ErrorCode.Server4xxOther, 'not found')
throw HttpClientError.from(newError(ErrorCode.Server4xxOther, 'not found'))
.withRequest('GET', 'https://api.dify.ai/v1/apps/x')
.withHttpStatus(404)
}

View File

@ -45,7 +45,10 @@ export async function run(tree: CommandTree, argv: string[]): Promise<void> {
if (typeof Ctor.deprecated === 'string' && Ctor.deprecated.length > 0)
process.stderr.write(`deprecated: ${Ctor.deprecated}\n`)
const cmd = new Ctor()
const output = await cmd.run(argv.slice(resolved.path.length))
const commandArgv = argv.slice(resolved.path.length)
cmd.processGlobalFlags(commandArgv)
const output = await cmd.run(commandArgv)
if (output !== undefined)
process.stdout.write(stringifyOutput(output))
}

View File

@ -3,7 +3,7 @@ import type { AddressInfo } from 'node:net'
import * as http from 'node:http'
import { startMock } from '@test/fixtures/dify-mock/server'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { isBaseError } from '@/errors/base'
import { isBaseError, isHttpClientError } from '@/errors/base'
import { ErrorCode } from '@/errors/codes'
import { openAPIBase } from '@/util/host'
import { createHttpClient } from './client.js'
@ -116,8 +116,8 @@ describe('http client', () => {
await client.get('workspaces')
}
catch (err) { caught = err }
expect(isBaseError(caught)).toBe(true)
if (isBaseError(caught)) {
expect(isHttpClientError(caught)).toBe(true)
if (isHttpClientError(caught)) {
expect(caught.code).toBe(ErrorCode.AuthExpired)
expect(caught.httpStatus).toBe(401)
expect(caught.method).toBe('GET')
@ -138,8 +138,8 @@ describe('http client', () => {
await client.get('workspaces')
}
catch (err) { caught = err }
expect(isBaseError(caught)).toBe(true)
if (isBaseError(caught)) {
expect(isHttpClientError(caught)).toBe(true)
if (isHttpClientError(caught)) {
expect(caught.code).toBe(ErrorCode.Server5xx)
expect(caught.httpStatus).toBe(503)
}
@ -187,8 +187,8 @@ describe('http client', () => {
await client.get('apps/nope/describe')
}
catch (err) { caught = err }
expect(isBaseError(caught)).toBe(true)
if (isBaseError(caught))
expect(isHttpClientError(caught)).toBe(true)
if (isHttpClientError(caught))
expect(caught.code).toBe(ErrorCode.Server4xxOther)
})
@ -205,8 +205,8 @@ describe('http client', () => {
await client.get('workspaces')
}
catch (err) { caught = err }
expect(isBaseError(caught)).toBe(true)
if (isBaseError(caught))
expect(isHttpClientError(caught)).toBe(true)
if (isHttpClientError(caught))
expect(caught.httpStatus).toBe(429)
})

View File

@ -9,6 +9,7 @@ import type {
RequestOptions,
ResolvedOptions,
} from './types.js'
import { isVerbose } from '@/framework/context'
import { userAgent as defaultUserAgent } from '@/version/info'
import { buildBody } from './body.js'
import { classifyResponse } from './error-mapper.js'
@ -133,11 +134,11 @@ async function dispatch(state: ClientState, path: string, opts: RequestOptions,
await runHooks(state.hooks.onRequest, ctx)
// `dispatcher` is an undici extension to RequestInit, not in @types/node's fetch
// signature — hence the local type. Carries proxy routing when a proxy env var is set.
const init: RequestInit & { dispatcher?: unknown } = { signal }
const init: RequestInit & { dispatcher?: unknown, verbose?: boolean } = { signal }
if (state.dispatcher !== undefined)
init.dispatcher = state.dispatcher
if (isVerbose())
init.verbose = true
try {
ctx.response = await fetch(ctx.request, init)

View File

@ -1,5 +1,4 @@
import type { BaseError } from '@/errors/base'
import { newError } from '@/errors/base'
import { BaseError, HttpClientError, newError } from '@/errors/base'
import { ErrorCode } from '@/errors/codes'
import { redactBearer } from './sanitize'
@ -32,54 +31,50 @@ async function readBody(response: Response): Promise<{ raw: string, parsed?: Wir
}
export async function classifyResponse(request: Request, response: Response): Promise<BaseError> {
const { parsed } = await readBody(response.clone())
const { parsed, raw } = await readBody(response.clone())
const wire: WireFields = parsed?.error ?? parsed ?? {}
const status = response.status
const url = redactBearer(response.url || request.url)
const method = request.method
if (status === 401) {
return newError(
return HttpClientError.from(newError(
ErrorCode.AuthExpired,
wire.message ?? 'session expired or revoked',
)
))
.withHint(wire.hint ?? 'run \'difyctl auth login\' to sign in again')
.withHttpStatus(status)
.withRequest(method, url)
}
if (status >= 500) {
return newError(
return HttpClientError.from(newError(
ErrorCode.Server5xx,
wire.message ?? `server error (HTTP ${status})`,
)
))
.withHttpStatus(status)
.withRequest(method, url)
.withRawResponse(raw)
}
const err = newError(
const err = HttpClientError.from(newError(
ErrorCode.Server4xxOther,
wire.message ?? `request failed (HTTP ${status})`,
)
))
.withHttpStatus(status)
.withRequest(method, url)
.withRawResponse(raw)
return wire.hint !== undefined ? err.withHint(wire.hint) : err
}
export function classifyTransportError(err: unknown): BaseError {
const message = err instanceof Error ? err.message : String(err)
const sanitized = redactBearer(message)
if (err instanceof Error && err.name === 'TimeoutError')
return newError(ErrorCode.NetworkTimeout, 'request timed out').wrap(err)
if (err instanceof Error && err.name === 'AbortError')
return newError(ErrorCode.NetworkTimeout, 'request aborted').wrap(err)
if (sanitized.toLowerCase().includes('econnrefused'))
return newError(ErrorCode.NetworkDns, 'connection refused').wrap(err)
if (sanitized.toLowerCase().includes('enotfound'))
return newError(ErrorCode.NetworkDns, 'host lookup failed').wrap(err)
if (sanitized.toLowerCase().includes('etimedout'))
return newError(ErrorCode.NetworkTimeout, 'connection timed out').wrap(err)
return newError(ErrorCode.Unknown, sanitized).wrap(err)
if (err instanceof BaseError) {
return err
}
if (!(err instanceof Error)) {
return newError(ErrorCode.Unknown, String(err)).wrap(err)
}
const sanitized = redactBearer(err.message)
// there isn't a practical way to classify network errors reliably
return newError(ErrorCode.NetworkConnection, sanitized).wrap(err)
}

View File

@ -1,7 +1,6 @@
{
"extends": "@dify/tsconfig/node.json",
"compilerOptions": {
"rootDir": "src",
"moduleResolution": "bundler",
"paths": {
"@/*": [
@ -12,12 +11,10 @@
]
},
"types": ["node"],
"declaration": true,
"declarationMap": true,
"noEmit": false,
"noEmit": true, // we already have bundlers to handle this.
"outDir": "dist",
"sourceMap": true
},
"include": ["src/**/*.ts"],
"exclude": ["dist", "test", "node_modules", "**/*.test.ts"]
"include": ["src/**/*.ts", "test/**/*.ts"], // tests must be included for typechecking
"exclude": ["node_modules", "dist"]
}

View File

@ -256,7 +256,9 @@ class DifyShellRuntimeState(BaseModel):
raise ValueError("workspace_cwd requires a matching session_id.")
expected_workspace = _workspace_cwd(self.session_id)
if self.workspace_cwd != expected_workspace:
raise ValueError(f"workspace_cwd must equal {expected_workspace!r} for session_id {self.session_id!r}.")
raise ValueError(
f"workspace_cwd must equal {expected_workspace!r} for session_id {self.session_id!r}."
)
unknown_offset_job_ids = set(self.job_offsets) - set(self.job_ids)
if unknown_offset_job_ids:
names = ", ".join(sorted(unknown_offset_job_ids))
@ -692,12 +694,12 @@ def _workspace_mkdir_script(*, session_id: str) -> str:
of silently reusing another session's workspace.
"""
safe_session_id = _validated_session_id(session_id)
workspace_dir = f"$HOME/workspace/{safe_session_id}"
workspace_dir = f'$HOME/workspace/{safe_session_id}'
return (
'mkdir -p "$HOME/workspace"; '
f'if mkdir "{workspace_dir}"; then exit 0; fi; '
f'if [ -e "{workspace_dir}" ]; then exit {_WORKSPACE_COLLISION_EXIT_CODE}; fi; '
"exit 1"
'exit 1'
)

View File

@ -277,7 +277,6 @@ def test_shell_layer_suspend_and_resume_reuse_state_with_fresh_clients() -> None
return next(clients)
compositor = Compositor([LayerNode("shell", _shell_provider(client_factory=factory))])
async def scenario() -> None:
async with compositor.enter(configs={"shell": DifyShellLayerConfig()}) as run:
shell_layer = run.get_layer("shell", DifyShellLayer)
@ -343,10 +342,7 @@ def test_shell_layer_delete_removes_workspace_then_force_deletes_tracked_jobs_an
assert client.events[:2] == [("run", 'rm -rf -- "$HOME/workspace/abc12ff"'), ("wait", "cleanup-job")]
assert {call.job_id for call in client.delete_calls} == {"user-job", "mkdir-job", "cleanup-job"}
assert all(
client.events.index(("delete", call.job_id)) > client.events.index(("wait", "cleanup-job"))
for call in client.delete_calls
)
assert all(client.events.index(("delete", call.job_id)) > client.events.index(("wait", "cleanup-job")) for call in client.delete_calls)
assert all(call.force is True for call in client.delete_calls)
assert layer.runtime_state.job_ids == []
assert layer.runtime_state.job_offsets == {}

View File

@ -27,9 +27,7 @@ def test_default_layer_providers_register_shell_layer_with_configured_token_fact
return factory
monkeypatch.setattr(
compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory
)
monkeypatch.setattr(compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory)
providers = create_default_layer_providers(
shellctl_entrypoint="http://shellctl.example",
@ -58,9 +56,7 @@ def test_default_layer_providers_keep_empty_shellctl_token_by_default(
return factory
monkeypatch.setattr(
compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory
)
monkeypatch.setattr(compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory)
providers = create_default_layer_providers(shellctl_entrypoint="http://shellctl.example")
shell_provider = next(provider for provider in providers if provider.type_id == DIFY_SHELL_LAYER_TYPE_ID)

View File

@ -684,8 +684,7 @@ def test_runner_rejects_duplicate_tool_names_between_shell_and_other_layers(
),
)
layer_providers = tuple(
provider
for provider in create_default_layer_providers(shellctl_entrypoint="http://unused")
provider for provider in create_default_layer_providers(shellctl_entrypoint="http://unused")
if provider.type_id != DIFY_SHELL_LAYER_TYPE_ID
) + (shell_provider,)

View File

@ -13,6 +13,7 @@
# Core service URLs
CONSOLE_API_URL=
SERVER_CONSOLE_API_URL=http://api:5001
CONSOLE_WEB_URL=
SERVICE_API_URL=
TRIGGER_URL=http://localhost

View File

@ -376,6 +376,7 @@ services:
- ./.env
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
SERVER_CONSOLE_API_URL: ${SERVER_CONSOLE_API_URL:-http://api:5001}
APP_API_URL: ${APP_API_URL:-}
AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-}
NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-}

View File

@ -382,6 +382,7 @@ services:
- ./.env
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
SERVER_CONSOLE_API_URL: ${SERVER_CONSOLE_API_URL:-http://api:5001}
APP_API_URL: ${APP_API_URL:-}
AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-}
NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-}

View File

@ -3406,14 +3406,6 @@
"count": 1
}
},
"web/app/components/workflow/block-selector/rag-tool-recommendations/index.tsx": {
"no-restricted-properties": {
"count": 3
},
"react/set-state-in-effect": {
"count": 1
}
},
"web/app/components/workflow/block-selector/tool-picker.tsx": {
"no-restricted-imports": {
"count": 1
@ -4405,11 +4397,6 @@
"count": 9
}
},
"web/app/components/workflow/nodes/question-classifier/components/class-list.tsx": {
"no-restricted-properties": {
"count": 2
}
},
"web/app/components/workflow/nodes/question-classifier/use-single-run-form-params.ts": {
"ts/no-explicit-any": {
"count": 8

View File

@ -4,14 +4,14 @@ import type { ViewType } from '@/app/components/workflow/block-selector/view-typ
import type { OnSelectBlock } from '@/app/components/workflow/types'
import { RiMoreLine } from '@remixicon/react'
import * as React from 'react'
import { useCallback, useEffect, useMemo, useState } from 'react'
import { useCallback, useMemo } from 'react'
import { Trans, useTranslation } from 'react-i18next'
import { ArrowDownRoundFill } from '@/app/components/base/icons/src/vender/solid/arrows'
import Loading from '@/app/components/base/loading'
import { getFormattedPlugin } from '@/app/components/plugins/marketplace/utils'
import { useLocalStorage } from '@/hooks/use-local-storage'
import Link from '@/next/link'
import { useRAGRecommendedPlugins } from '@/service/use-tools'
import { isServer } from '@/utils/client'
import { getMarketplaceUrl } from '@/utils/var'
import List from './list'
@ -29,26 +29,7 @@ const RAGToolRecommendations = ({
onTagsChange,
}: RAGToolRecommendationsProps) => {
const { t } = useTranslation()
const [isCollapsed, setIsCollapsed] = useState<boolean>(() => {
if (isServer)
return false
const stored = window.localStorage.getItem(STORAGE_KEY)
return stored === 'true'
})
useEffect(() => {
if (isServer)
return
const stored = window.localStorage.getItem(STORAGE_KEY)
if (stored !== null)
setIsCollapsed(stored === 'true')
}, [])
useEffect(() => {
if (isServer)
return
window.localStorage.setItem(STORAGE_KEY, String(isCollapsed))
}, [isCollapsed])
const [isCollapsed, setIsCollapsed] = useLocalStorage<boolean>(STORAGE_KEY, false)
const {
data: ragRecommendedPlugins,

View File

@ -11,6 +11,7 @@ import { useCallback, useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { ReactSortable } from 'react-sortablejs'
import { ArrowDownRoundFill } from '@/app/components/base/icons/src/vender/solid/general'
import { useLocalStorage } from '@/hooks/use-local-storage'
import { useEdgesInteractions } from '../../../hooks'
import AddButton from '../../_base/components/add-button'
import Item from './class-item'
@ -42,17 +43,8 @@ const ClassList: FC<Props> = ({
const [shouldScrollToEnd, setShouldScrollToEnd] = useState(false)
const prevListLength = useRef(list.length)
const [collapsed, setCollapsed] = useState(false)
const [isRenameHintDismissed, setIsRenameHintDismissed] = useState(() => {
if (typeof window === 'undefined')
return true
try {
return window.localStorage.getItem(INLINE_LABEL_HINT_STORAGE_KEY) === 'true'
}
catch {
return false
}
})
const [storedRenameHintDismissed, setIsRenameHintDismissed] = useLocalStorage<boolean>(INLINE_LABEL_HINT_STORAGE_KEY)
const isRenameHintDismissed = storedRenameHintDismissed ?? false
const handleClassChange = useCallback((index: number) => {
return (value: Topic) => {
@ -104,12 +96,7 @@ const ClassList: FC<Props> = ({
return
setIsRenameHintDismissed(true)
try {
window.localStorage.setItem(INLINE_LABEL_HINT_STORAGE_KEY, 'true')
}
catch {
}
}, [isRenameHintDismissed])
}, [isRenameHintDismissed, setIsRenameHintDismissed])
const shouldShowRenameHint = !readonly && !isRenameHintDismissed && list.some((item, index) => {
return isDefaultClassLabel(item.label, index + 1, t)

View File

@ -0,0 +1,43 @@
// @vitest-environment node
import { afterEach, describe, expect, it, vi } from 'vitest'
vi.mock('server-only', () => ({}))
const importServerConfig = async () => {
vi.resetModules()
return import('../server')
}
describe('server config', () => {
afterEach(() => {
vi.unstubAllEnvs()
})
it('should prefer the server-only console API URL for server requests', async () => {
vi.stubEnv('SERVER_CONSOLE_API_URL', 'http://api:5001')
vi.stubEnv('CONSOLE_API_URL', 'https://console.example.com')
const { SERVER_CONSOLE_API_PREFIX } = await importServerConfig()
expect(SERVER_CONSOLE_API_PREFIX).toBe('http://api:5001/console/api')
})
it('should fall back to the public console API URL when no server-only URL is configured', async () => {
vi.stubEnv('SERVER_CONSOLE_API_URL', '')
vi.stubEnv('CONSOLE_API_URL', 'https://console.example.com')
const { SERVER_CONSOLE_API_PREFIX } = await importServerConfig()
expect(SERVER_CONSOLE_API_PREFIX).toBe('https://console.example.com/console/api')
})
it('should remain unconfigured when both server URLs are empty', async () => {
vi.stubEnv('SERVER_CONSOLE_API_URL', '')
vi.stubEnv('CONSOLE_API_URL', '')
const { SERVER_CONSOLE_API_PREFIX } = await importServerConfig()
expect(SERVER_CONSOLE_API_PREFIX).toBeUndefined()
})
})

View File

@ -5,6 +5,8 @@ import 'server-only'
const withoutTrailingSlash = (value: string) => value.endsWith('/') ? value.slice(0, -1) : value
// Server-side requests need the origin; browser requests should keep using NEXT_PUBLIC_API_PREFIX.
export const SERVER_CONSOLE_API_PREFIX = env.CONSOLE_API_URL
? `${withoutTrailingSlash(env.CONSOLE_API_URL)}/console/api`
const serverConsoleApiUrl = env.SERVER_CONSOLE_API_URL || env.CONSOLE_API_URL
export const SERVER_CONSOLE_API_PREFIX = serverConsoleApiUrl
? `${withoutTrailingSlash(serverConsoleApiUrl)}/console/api`
: undefined

View File

@ -161,6 +161,7 @@ const clientSchema = {
export const env = createEnv({
server: {
CONSOLE_API_URL: z.string().optional(),
SERVER_CONSOLE_API_URL: z.string().optional(),
/**
* Maximum length of segmentation tokens for indexing
*/