mirror of
https://github.com/langgenius/dify.git
synced 2026-06-10 02:16:51 +08:00
Compare commits
7 Commits
fix/wta-47
...
cli-reques
| Author | SHA1 | Date | |
|---|---|---|---|
| e429b97245 | |||
| c5bf751877 | |||
| 1cb9fdc194 | |||
| e3fd4bd98e | |||
| 194de58615 | |||
| bb9c8272ac | |||
| 1ce11a45bf |
81
api/controllers/openapi/_contract.py
Normal file
81
api/controllers/openapi/_contract.py
Normal file
@ -0,0 +1,81 @@
|
||||
"""Request/response contract decorators for the openapi controllers.
|
||||
|
||||
``@accepts`` and ``@returns`` own one slice of the contract from a single model
|
||||
reference — emitting the Swagger schema AND doing the runtime validation/
|
||||
serialisation — so the advertised and enforced contracts can't drift. Validation
|
||||
failures map to a single shape: 422.
|
||||
|
||||
They must sit BELOW ``@auth_router.guard`` so auth runs before validation and the
|
||||
``view.__wrapped__`` unit-test seam unwraps exactly the guard layer.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import abort
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from controllers.common.schema import query_params_from_model, query_params_from_request
|
||||
from controllers.openapi import openapi_ns
|
||||
|
||||
|
||||
def accepts(*, query: type[BaseModel] | None = None, body: type[BaseModel] | None = None) -> Callable:
|
||||
"""Validate ``query``/``body`` against the models and inject them as keyword-only kwargs.
|
||||
|
||||
Emits the matching Swagger schema from the same models, so doc and enforcement
|
||||
stay in lockstep.
|
||||
"""
|
||||
|
||||
def decorator(view: Callable) -> Callable:
|
||||
@wraps(view)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
if query is not None:
|
||||
kwargs["query"] = query_params_from_request(query)
|
||||
if body is not None:
|
||||
kwargs["body"] = body.model_validate(request.get_json(silent=True) or {})
|
||||
except ValidationError as exc:
|
||||
# Sanitized 422 — no pydantic `url` (version) or `input` (user payload) leak.
|
||||
abort(
|
||||
422,
|
||||
message="Request validation failed",
|
||||
errors=exc.errors(include_url=False, include_input=False, include_context=False),
|
||||
)
|
||||
return view(*args, **kwargs)
|
||||
|
||||
if query is not None:
|
||||
openapi_ns.doc(params=query_params_from_model(query))(wrapper)
|
||||
if body is not None:
|
||||
openapi_ns.expect(openapi_ns.models[body.__name__])(wrapper)
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def returns(code: int, model: type[BaseModel], description: str | None = None) -> Callable:
|
||||
"""Serialise the handler's returned model and emit the response schema.
|
||||
|
||||
Accepts a ``BaseModel`` (serialised with ``code``) or a ``(model, status[, headers])``
|
||||
tuple (status/headers honoured). Other returns — a bare ``(dict, status)``, an SSE
|
||||
``Response`` — pass through untouched.
|
||||
"""
|
||||
|
||||
def decorator(view: Callable) -> Callable:
|
||||
@wraps(view)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
result = view(*args, **kwargs)
|
||||
if isinstance(result, BaseModel):
|
||||
return result.model_dump(mode="json"), code
|
||||
if isinstance(result, tuple) and result and isinstance(result[0], BaseModel):
|
||||
payload, *rest = result
|
||||
return (payload.model_dump(mode="json"), *rest)
|
||||
return result
|
||||
|
||||
openapi_ns.response(code, description or model.__name__, openapi_ns.models[model.__name__])(wrapper)
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
@ -9,15 +9,16 @@ from flask_restx import Resource
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import returns
|
||||
from controllers.openapi._models import ServerVersionResponse
|
||||
|
||||
|
||||
@openapi_ns.route("/_version")
|
||||
class VersionApi(Resource):
|
||||
@openapi_ns.response(200, "Server version", openapi_ns.models[ServerVersionResponse.__name__])
|
||||
@returns(200, ServerVersionResponse, description="Server version")
|
||||
def get(self):
|
||||
edition = dify_config.EDITION if dify_config.EDITION in ("SELF_HOSTED", "CLOUD") else "SELF_HOSTED"
|
||||
return ServerVersionResponse(
|
||||
version=dify_config.project.version,
|
||||
edition=edition,
|
||||
).model_dump(mode="json")
|
||||
)
|
||||
|
||||
@ -2,17 +2,14 @@ from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import NotFound, UnprocessableEntity
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._models import (
|
||||
AccountPayload,
|
||||
AccountResponse,
|
||||
PaginationEnvelope,
|
||||
RevokeResponse,
|
||||
SessionListQuery,
|
||||
SessionListResponse,
|
||||
@ -42,8 +39,8 @@ from services.oauth_device_flow import (
|
||||
|
||||
@openapi_ns.route("/account")
|
||||
class AccountApi(Resource):
|
||||
@openapi_ns.response(200, "Account info", openapi_ns.models[AccountResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@returns(200, AccountResponse, description="Account info")
|
||||
def get(self, *, auth_data: AuthData):
|
||||
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{auth_data.account_id}")
|
||||
|
||||
@ -58,31 +55,27 @@ class AccountApi(Resource):
|
||||
account=_account_payload(account) if account else None,
|
||||
workspaces=[_workspace_payload(m) for m in memberships],
|
||||
default_workspace_id=default_ws_id,
|
||||
).model_dump(mode="json")
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions/self")
|
||||
class AccountSessionsSelfApi(Resource):
|
||||
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@returns(200, RevokeResponse, description="Session revoked")
|
||||
def delete(self, *, auth_data: AuthData):
|
||||
revoke_oauth_token(db.session, redis_client, str(auth_data.token_id))
|
||||
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||
return RevokeResponse(status="revoked")
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions")
|
||||
class AccountSessionsApi(Resource):
|
||||
@openapi_ns.doc(params=query_params_from_model(SessionListQuery))
|
||||
@openapi_ns.response(200, "Session list", openapi_ns.models[SessionListResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
# Validate page/limit through the same model the contract advertises (extra='forbid',
|
||||
# page>=1, 1<=limit<=MAX_PAGE_LIMIT) so the server actually enforces those bounds rather
|
||||
# than silently coercing (e.g. page=0 -> empty slice). Mirrors AppDescribeQuery.
|
||||
try:
|
||||
query = SessionListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
@returns(200, SessionListResponse, description="Session list")
|
||||
@accepts(query=SessionListQuery)
|
||||
def get(self, *, auth_data: AuthData, query: SessionListQuery):
|
||||
# SessionListQuery enforces the advertised bounds (extra='forbid', page>=1,
|
||||
# 1<=limit<=MAX_PAGE_LIMIT) so the server rejects out-of-range paging rather
|
||||
# than silently coercing (e.g. page=0 -> empty slice).
|
||||
ctx = get_auth_ctx()
|
||||
now = datetime.now(UTC)
|
||||
page = query.page
|
||||
@ -106,16 +99,19 @@ class AccountSessionsApi(Resource):
|
||||
for r in sliced
|
||||
]
|
||||
|
||||
return (
|
||||
PaginationEnvelope.build(page=page, limit=limit, total=total, items=items).model_dump(mode="json"),
|
||||
200,
|
||||
return SessionListResponse(
|
||||
page=page,
|
||||
limit=limit,
|
||||
total=total,
|
||||
has_more=page * limit < total,
|
||||
data=items,
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions/<string:session_id>")
|
||||
class AccountSessionByIdApi(Resource):
|
||||
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@returns(200, RevokeResponse, description="Session revoked")
|
||||
def delete(self, session_id: str, *, auth_data: AuthData):
|
||||
ctx = get_auth_ctx()
|
||||
|
||||
@ -125,7 +121,7 @@ class AccountSessionByIdApi(Resource):
|
||||
raise NotFound("session not found")
|
||||
|
||||
revoke_oauth_token(db.session, redis_client, session_id)
|
||||
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||
return RevokeResponse(status="revoked")
|
||||
|
||||
|
||||
def _iso(dt: datetime | None) -> str | None:
|
||||
|
||||
@ -7,14 +7,13 @@ from collections.abc import Callable, Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import BadRequest, HTTPException, InternalServerError, NotFound, UnprocessableEntity
|
||||
|
||||
import services
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._audit import emit_app_run
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._models import AppRunRequest, TaskStopResponse
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
@ -123,23 +122,18 @@ _DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = {
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/run")
|
||||
class AppRunApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[AppRunRequest.__name__])
|
||||
@openapi_ns.response(200, "Run result (SSE stream)")
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, *, auth_data: AuthData):
|
||||
@openapi_ns.response(200, "Run result (SSE stream)")
|
||||
@accepts(body=AppRunRequest)
|
||||
def post(self, app_id: str, *, auth_data: AuthData, body: AppRunRequest):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
payload = AppRunRequest.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
handler = _DISPATCH.get(app_model.mode)
|
||||
if handler is None:
|
||||
raise UnprocessableEntity("mode_not_runnable")
|
||||
|
||||
try:
|
||||
stream_obj = handler(app_model, caller, payload)
|
||||
stream_obj = handler(app_model, caller, body)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
@ -159,10 +153,10 @@ class AppRunApi(Resource):
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/stop")
|
||||
class AppRunTaskStopApi(Resource):
|
||||
@openapi_ns.response(200, "Task stopped", openapi_ns.models[TaskStopResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
@returns(200, TaskStopResponse, description="Task stopped")
|
||||
def post(self, app_id: str, task_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
return {"result": "success"}
|
||||
return TaskStopResponse(result="success")
|
||||
|
||||
@ -5,14 +5,12 @@ from __future__ import annotations
|
||||
import uuid as _uuid
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.common.fields import Parameters
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema, resolve_app_config
|
||||
from controllers.openapi._models import (
|
||||
AppDescribeInfo,
|
||||
@ -88,15 +86,11 @@ def parameters_payload(app: App) -> dict:
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/describe")
|
||||
class AppDescribeApi(AppReadResource):
|
||||
@openapi_ns.doc(params=query_params_from_model(AppDescribeQuery))
|
||||
@openapi_ns.response(200, "App description", openapi_ns.models[AppDescribeResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, app_id: str, *, auth_data: AuthData):
|
||||
try:
|
||||
query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
@returns(200, AppDescribeResponse, description="App description")
|
||||
@accepts(query=AppDescribeQuery)
|
||||
def get(self, app_id: str, *, auth_data: AuthData, query: AppDescribeQuery):
|
||||
# describe is UUID-only (workspace_id query param dropped in #37212).
|
||||
app = self._load(app_id)
|
||||
|
||||
requested = query.fields
|
||||
@ -133,35 +127,22 @@ class AppDescribeApi(AppReadResource):
|
||||
except AppUnavailableError:
|
||||
input_schema = dict(EMPTY_INPUT_SCHEMA)
|
||||
|
||||
return (
|
||||
AppDescribeResponse(
|
||||
info=info,
|
||||
parameters=parameters,
|
||||
input_schema=input_schema,
|
||||
).model_dump(mode="json", exclude_none=False),
|
||||
200,
|
||||
return AppDescribeResponse(
|
||||
info=info,
|
||||
parameters=parameters,
|
||||
input_schema=input_schema,
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/apps")
|
||||
class AppListApi(Resource):
|
||||
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
|
||||
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
|
||||
@auth_router.guard_workspace(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
try:
|
||||
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
@returns(200, AppListResponse, description="App list")
|
||||
@accepts(query=AppListQuery)
|
||||
def get(self, *, auth_data: AuthData, query: AppListQuery):
|
||||
workspace_id = query.workspace_id
|
||||
|
||||
empty = (
|
||||
AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[]).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
200,
|
||||
)
|
||||
empty = AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[])
|
||||
|
||||
if query.name:
|
||||
try:
|
||||
@ -189,7 +170,7 @@ class AppListApi(Resource):
|
||||
workspace_name=tenant_name,
|
||||
)
|
||||
env = AppListResponse(page=1, limit=1, total=1, has_more=False, data=[item])
|
||||
return env.model_dump(mode="json"), 200
|
||||
return env
|
||||
|
||||
tag_ids: list[str] | None = None
|
||||
if query.tag:
|
||||
@ -240,4 +221,4 @@ class AppListApi(Resource):
|
||||
has_more=query.page * query.limit < cast(int, pagination.total),
|
||||
data=items,
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
return env
|
||||
|
||||
@ -7,12 +7,10 @@ EE blueprint chain so this module is unreachable there.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import UnprocessableEntity
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._models import (
|
||||
AppListRow,
|
||||
PermittedExternalAppsListQuery,
|
||||
@ -30,20 +28,14 @@ from services.enterprise.app_permitted_service import list_permitted_apps
|
||||
|
||||
@openapi_ns.route("/permitted-external-apps")
|
||||
class PermittedExternalAppsListApi(Resource):
|
||||
@openapi_ns.response(
|
||||
200, "Permitted external apps list", openapi_ns.models[PermittedExternalAppsListResponse.__name__]
|
||||
)
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_READ_PERMITTED_EXTERNAL,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_EXTERNAL_SSO}),
|
||||
edition=frozenset({Edition.EE}),
|
||||
)
|
||||
def get(self, *, auth_data: AuthData):
|
||||
try:
|
||||
query = PermittedExternalAppsListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
@returns(200, PermittedExternalAppsListResponse, description="Permitted external apps list")
|
||||
@accepts(query=PermittedExternalAppsListQuery)
|
||||
def get(self, *, auth_data: AuthData, query: PermittedExternalAppsListQuery):
|
||||
page_result = list_permitted_apps(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
@ -55,7 +47,7 @@ class PermittedExternalAppsListApi(Resource):
|
||||
env = PermittedExternalAppsListResponse(
|
||||
page=query.page, limit=query.limit, total=page_result.total, has_more=False, data=[]
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
return env
|
||||
|
||||
apps_by_id: dict[str, App] = {
|
||||
str(a.id): a for a in AppService.find_visible_apps_by_ids(db.session, page_result.app_ids)
|
||||
@ -89,4 +81,4 @@ class PermittedExternalAppsListApi(Resource):
|
||||
has_more=query.page * query.limit < page_result.total,
|
||||
data=items,
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
return env
|
||||
|
||||
@ -17,6 +17,7 @@ from controllers.common.errors import (
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import returns
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
@ -38,8 +39,8 @@ class AppFileUploadApi(Resource):
|
||||
415: "Unsupported file type or blocked extension",
|
||||
}
|
||||
)
|
||||
@openapi_ns.response(HTTPStatus.CREATED, "File uploaded", openapi_ns.models[FileResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
@returns(HTTPStatus.CREATED, FileResponse, description="File uploaded")
|
||||
def post(self, app_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, _ = auth_data.require_app_context()
|
||||
if "file" not in request.files:
|
||||
@ -69,5 +70,4 @@ class AppFileUploadApi(Resource):
|
||||
except services.errors.file.BlockedFileExtensionError as exc:
|
||||
raise BlockedFileExtensionError(exc.description)
|
||||
|
||||
response = FileResponse.model_validate(upload_file, from_attributes=True)
|
||||
return response.model_dump(mode="json"), 201
|
||||
return FileResponse.model_validate(upload_file, from_attributes=True)
|
||||
|
||||
@ -10,13 +10,14 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
|
||||
from flask import Response, request
|
||||
from flask import Response
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._models import FormSubmitResponse
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
@ -70,12 +71,11 @@ class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||
service.ensure_form_active(form)
|
||||
return _jsonify_form_definition(form)
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
@openapi_ns.response(200, "Form submitted", openapi_ns.models[FormSubmitResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, form_token: str, *, auth_data: AuthData):
|
||||
@returns(200, FormSubmitResponse, description="Form submitted")
|
||||
@accepts(body=HumanInputFormSubmitPayload)
|
||||
def post(self, app_id: str, form_token: str, *, auth_data: AuthData, body: HumanInputFormSubmitPayload):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
payload = HumanInputFormSubmitPayload.model_validate(request.get_json(silent=True) or {})
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
@ -100,12 +100,12 @@ class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||
service.submit_form_by_token(
|
||||
recipient_type=form.recipient_type,
|
||||
form_token=form_token,
|
||||
selected_action_id=payload.action,
|
||||
form_data=payload.inputs,
|
||||
selected_action_id=body.action,
|
||||
form_data=body.inputs,
|
||||
submission_user_id=submission_user_id,
|
||||
submission_end_user_id=submission_end_user_id,
|
||||
)
|
||||
except FormNotFoundError:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
return {}, 200
|
||||
return FormSubmitResponse()
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import returns
|
||||
from controllers.openapi._models import HealthResponse
|
||||
|
||||
|
||||
@openapi_ns.route("/_health")
|
||||
class HealthApi(Resource):
|
||||
@openapi_ns.response(200, "Health check", openapi_ns.models[HealthResponse.__name__])
|
||||
@returns(200, HealthResponse, description="Health check")
|
||||
def get(self):
|
||||
return {"ok": True}
|
||||
return HealthResponse(ok=True)
|
||||
|
||||
@ -14,14 +14,13 @@ from __future__ import annotations
|
||||
from itertools import starmap
|
||||
from urllib import parse
|
||||
|
||||
from flask import jsonify, make_response, request
|
||||
from flask import jsonify, make_response
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._models import (
|
||||
MemberActionResponse,
|
||||
MemberInvitePayload,
|
||||
@ -53,14 +52,6 @@ from services.errors.account import (
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
def _validate_body[M: BaseModel](model: type[M]) -> M:
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
return model.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
|
||||
def _member_response(account: Account) -> MemberResponse:
|
||||
return MemberResponse(
|
||||
id=str(account.id),
|
||||
@ -118,18 +109,18 @@ def _check_member_invite_quota(tenant_id: str) -> None:
|
||||
|
||||
@openapi_ns.route("/workspaces")
|
||||
class WorkspacesApi(Resource):
|
||||
@openapi_ns.response(200, "Workspace list", openapi_ns.models[WorkspaceListResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@returns(200, WorkspaceListResponse, description="Workspace list")
|
||||
def get(self, *, auth_data: AuthData):
|
||||
rows = TenantService.get_workspaces_for_account(db.session, str(auth_data.account_id))
|
||||
|
||||
return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows))).model_dump(mode="json"), 200
|
||||
return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows)))
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>")
|
||||
class WorkspaceByIdApi(Resource):
|
||||
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@returns(200, WorkspaceDetailResponse, description="Workspace detail")
|
||||
def get(self, workspace_id: str, *, auth_data: AuthData):
|
||||
row = TenantService.find_workspace_for_account(db.session, str(auth_data.account_id), workspace_id)
|
||||
# 404 (not 403) on non-member so workspace IDs don't leak across tenants.
|
||||
@ -137,7 +128,7 @@ class WorkspaceByIdApi(Resource):
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
tenant, membership = row
|
||||
return _workspace_detail(tenant, membership).model_dump(mode="json"), 200
|
||||
return _workspace_detail(tenant, membership)
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/switch")
|
||||
@ -149,8 +140,8 @@ class WorkspaceSwitchApi(Resource):
|
||||
that ``hosts.yml`` never diverges from the server's ``current`` state.
|
||||
"""
|
||||
|
||||
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
|
||||
@auth_router.guard_workspace(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@returns(200, WorkspaceDetailResponse, description="Workspace detail")
|
||||
def post(self, workspace_id: str, *, auth_data: AuthData):
|
||||
account = _load_account(auth_data.account_id)
|
||||
|
||||
@ -163,7 +154,7 @@ class WorkspaceSwitchApi(Resource):
|
||||
if row is None:
|
||||
raise NotFound("workspace not found")
|
||||
tenant, membership = row
|
||||
return _workspace_detail(tenant, membership).model_dump(mode="json"), 200
|
||||
return _workspace_detail(tenant, membership)
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/members")
|
||||
@ -174,15 +165,10 @@ class WorkspaceMembersApi(Resource):
|
||||
assigned through invite (ownership transfer is console-only).
|
||||
"""
|
||||
|
||||
@openapi_ns.doc(params=query_params_from_model(MemberListQuery))
|
||||
@openapi_ns.response(200, "Member list", openapi_ns.models[MemberListResponse.__name__])
|
||||
@auth_router.guard_workspace(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, workspace_id: str, *, auth_data: AuthData):
|
||||
try:
|
||||
query = MemberListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
@returns(200, MemberListResponse, description="Member list")
|
||||
@accepts(query=MemberListQuery)
|
||||
def get(self, workspace_id: str, *, auth_data: AuthData, query: MemberListQuery):
|
||||
tenant = _load_tenant(workspace_id)
|
||||
members = TenantService.get_tenant_members(tenant)
|
||||
total = len(members)
|
||||
@ -194,17 +180,16 @@ class WorkspaceMembersApi(Resource):
|
||||
total=total,
|
||||
has_more=query.page * query.limit < total,
|
||||
data=[_member_response(m) for m in page_items],
|
||||
).model_dump(mode="json"), 200
|
||||
)
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[MemberInvitePayload.__name__])
|
||||
@openapi_ns.response(201, "Member invited", openapi_ns.models[MemberInviteResponse.__name__])
|
||||
@auth_router.guard_workspace(
|
||||
scope=Scope.WORKSPACE_WRITE,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
|
||||
)
|
||||
def post(self, workspace_id: str, *, auth_data: AuthData):
|
||||
payload = _validate_body(MemberInvitePayload)
|
||||
@returns(201, MemberInviteResponse, description="Member invited")
|
||||
@accepts(body=MemberInvitePayload)
|
||||
def post(self, workspace_id: str, *, auth_data: AuthData, body: MemberInvitePayload):
|
||||
inviter = _load_account(auth_data.account_id)
|
||||
tenant = _load_tenant(workspace_id)
|
||||
|
||||
@ -213,9 +198,9 @@ class WorkspaceMembersApi(Resource):
|
||||
try:
|
||||
token = RegisterService.invite_new_member(
|
||||
tenant=tenant,
|
||||
email=payload.email,
|
||||
email=body.email,
|
||||
language=None,
|
||||
role=payload.role,
|
||||
role=body.role,
|
||||
inviter=inviter,
|
||||
)
|
||||
except AccountAlreadyInTenantError as exc:
|
||||
@ -225,7 +210,7 @@ class WorkspaceMembersApi(Resource):
|
||||
except AccountRegisterError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
normalized_email = payload.email.lower()
|
||||
normalized_email = body.email.lower()
|
||||
member = AccountService.get_account_by_email_with_case_fallback(normalized_email)
|
||||
if member is None:
|
||||
# invite_new_member just created or fetched this account.
|
||||
@ -235,11 +220,11 @@ class WorkspaceMembersApi(Resource):
|
||||
invite_url = f"{dify_config.CONSOLE_WEB_URL}/activate?email={encoded_email}&token={token}"
|
||||
return MemberInviteResponse(
|
||||
email=normalized_email,
|
||||
role=payload.role,
|
||||
role=body.role,
|
||||
member_id=str(member.id),
|
||||
invite_url=invite_url,
|
||||
tenant_id=str(tenant.id),
|
||||
).model_dump(mode="json"), 201
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/members/<string:member_id>")
|
||||
@ -251,12 +236,12 @@ class WorkspaceMemberApi(Resource):
|
||||
400 per the spec, with the service's message preserved.
|
||||
"""
|
||||
|
||||
@openapi_ns.response(200, "Member removed", openapi_ns.models[MemberActionResponse.__name__])
|
||||
@auth_router.guard_workspace(
|
||||
scope=Scope.WORKSPACE_WRITE,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
|
||||
)
|
||||
@returns(200, MemberActionResponse, description="Member removed")
|
||||
def delete(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
|
||||
operator = _load_account(auth_data.account_id)
|
||||
tenant = _load_tenant(workspace_id)
|
||||
@ -273,7 +258,7 @@ class WorkspaceMemberApi(Resource):
|
||||
except MemberNotInTenantError as exc:
|
||||
raise NotFound(str(exc))
|
||||
|
||||
return MemberActionResponse().model_dump(mode="json"), 200
|
||||
return MemberActionResponse()
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/members/<string:member_id>/role")
|
||||
@ -284,15 +269,14 @@ class WorkspaceMemberRoleApi(Resource):
|
||||
standing owner (service NoPermissionError → 400, per spec).
|
||||
"""
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[MemberRoleUpdatePayload.__name__])
|
||||
@openapi_ns.response(200, "Role updated", openapi_ns.models[MemberActionResponse.__name__])
|
||||
@auth_router.guard_workspace(
|
||||
scope=Scope.WORKSPACE_WRITE,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
|
||||
)
|
||||
def put(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
|
||||
payload = _validate_body(MemberRoleUpdatePayload)
|
||||
@returns(200, MemberActionResponse, description="Role updated")
|
||||
@accepts(body=MemberRoleUpdatePayload)
|
||||
def put(self, workspace_id: str, member_id: str, *, auth_data: AuthData, body: MemberRoleUpdatePayload):
|
||||
operator = _load_account(auth_data.account_id)
|
||||
tenant = _load_tenant(workspace_id)
|
||||
member = AccountService.get_account_by_id(db.session, member_id)
|
||||
@ -300,7 +284,7 @@ class WorkspaceMemberRoleApi(Resource):
|
||||
raise NotFound("member not found")
|
||||
|
||||
try:
|
||||
TenantService.update_member_role(tenant, member, payload.role, operator)
|
||||
TenantService.update_member_role(tenant, member, body.role, operator)
|
||||
except CannotOperateSelfError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
except NoPermissionError as exc:
|
||||
@ -310,7 +294,7 @@ class WorkspaceMemberRoleApi(Resource):
|
||||
except RoleAlreadyAssignedError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
return MemberActionResponse().model_dump(mode="json"), 200
|
||||
return MemberActionResponse()
|
||||
|
||||
|
||||
def _workspace_summary(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceSummaryResponse:
|
||||
|
||||
@ -299,6 +299,15 @@ Upload a file to use as an input variable when running the app
|
||||
### /permitted-external-apps
|
||||
|
||||
#### GET
|
||||
##### Parameters
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| limit | query | | No | integer |
|
||||
| mode | query | | No | string |
|
||||
| name | query | | No | string |
|
||||
| page | query | | No | integer |
|
||||
|
||||
##### Responses
|
||||
|
||||
| Code | Description | Schema |
|
||||
|
||||
@ -94,4 +94,4 @@ def test_stop_task_calls_queue_manager_and_graph_engine(app, bypass_pipeline, mo
|
||||
|
||||
queue_mock.set_stop_flag_no_user_check.assert_called_once_with("task-1")
|
||||
graph_instance.send_stop_command.assert_called_once_with("task-1")
|
||||
assert result == {"result": "success"}
|
||||
assert result == ({"result": "success"}, 200)
|
||||
|
||||
210
api/tests/unit_tests/controllers/openapi/test_contract.py
Normal file
210
api/tests/unit_tests/controllers/openapi/test_contract.py
Normal file
@ -0,0 +1,210 @@
|
||||
"""Unit tests for the @accepts / @returns contract decorators.
|
||||
|
||||
Exercises the decorators in isolation (not through a real controller): a plain
|
||||
view function decorated with @accepts/@returns, driven inside a request context.
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from werkzeug.exceptions import UnprocessableEntity
|
||||
|
||||
from controllers.common.schema import register_response_schema_model, register_schema_model
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
|
||||
|
||||
class ContractQuery(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
page: int = Field(1, ge=1)
|
||||
limit: int = Field(20, ge=1, le=100)
|
||||
|
||||
|
||||
class ContractBody(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
name: str
|
||||
|
||||
|
||||
class ContractResp(BaseModel):
|
||||
value: int
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def _register_contract_test_models():
|
||||
# Register for @accepts(body=)/@returns name lookups; drop on teardown so these
|
||||
# test-only models don't leak into the shared openapi_ns / generated spec.
|
||||
register_schema_model(openapi_ns, ContractBody)
|
||||
register_response_schema_model(openapi_ns, ContractResp)
|
||||
yield
|
||||
openapi_ns.models.pop(ContractBody.__name__, None)
|
||||
openapi_ns.models.pop(ContractResp.__name__, None)
|
||||
|
||||
|
||||
def _guard_like(view):
|
||||
"""Stand-in for ``@auth_router.guard`` — an outermost @wraps layer."""
|
||||
|
||||
@wraps(view)
|
||||
def wrapper(*args, **kwargs):
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def test_accepts_injects_validated_query(app):
|
||||
@accepts(query=ContractQuery)
|
||||
def view(*, query):
|
||||
return query
|
||||
|
||||
with app.test_request_context("/?page=3&limit=5"):
|
||||
result = view()
|
||||
|
||||
assert isinstance(result, ContractQuery)
|
||||
assert result.page == 3
|
||||
assert result.limit == 5
|
||||
|
||||
|
||||
def test_accepts_query_uses_defaults_when_absent(app):
|
||||
@accepts(query=ContractQuery)
|
||||
def view(*, query):
|
||||
return query
|
||||
|
||||
with app.test_request_context("/"):
|
||||
result = view()
|
||||
|
||||
assert result.page == 1
|
||||
assert result.limit == 20
|
||||
|
||||
|
||||
@pytest.mark.parametrize("query_string", ["page=0", "limit=999", "page=abc", "unknown=1"])
|
||||
def test_accepts_rejects_invalid_query_with_422(app, query_string):
|
||||
@accepts(query=ContractQuery)
|
||||
def view(*, query):
|
||||
return query
|
||||
|
||||
with app.test_request_context(f"/?{query_string}"):
|
||||
with pytest.raises(UnprocessableEntity):
|
||||
view()
|
||||
|
||||
|
||||
def test_accepts_validation_error_is_sanitized_and_structured(app):
|
||||
"""422 body is structured and leaks neither the pydantic docs url nor the user input."""
|
||||
|
||||
@accepts(body=ContractBody)
|
||||
def view(*, body):
|
||||
return body
|
||||
|
||||
with app.test_request_context("/", method="POST", json={"secret": "leak-me"}):
|
||||
with pytest.raises(UnprocessableEntity) as exc_info:
|
||||
view()
|
||||
|
||||
data = exc_info.value.data
|
||||
assert data["message"] == "Request validation failed"
|
||||
assert isinstance(data["errors"], list)
|
||||
assert data["errors"]
|
||||
for err in data["errors"]:
|
||||
assert {"type", "loc", "msg"} <= err.keys()
|
||||
assert "url" not in err
|
||||
assert "input" not in err
|
||||
assert "leak-me" not in str(data)
|
||||
|
||||
|
||||
def test_accepts_injects_validated_body(app):
|
||||
@accepts(body=ContractBody)
|
||||
def view(*, body):
|
||||
return body
|
||||
|
||||
with app.test_request_context("/", method="POST", json={"name": "x"}):
|
||||
result = view()
|
||||
|
||||
assert isinstance(result, ContractBody)
|
||||
assert result.name == "x"
|
||||
|
||||
|
||||
def test_accepts_rejects_invalid_body_with_422(app):
|
||||
@accepts(body=ContractBody)
|
||||
def view(*, body):
|
||||
return body
|
||||
|
||||
with app.test_request_context("/", method="POST", json={"wrong": 1}):
|
||||
with pytest.raises(UnprocessableEntity):
|
||||
view()
|
||||
|
||||
|
||||
def test_returns_serializes_model_with_decorator_status(app):
|
||||
@returns(200, ContractResp)
|
||||
def view():
|
||||
return ContractResp(value=7)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
body, status = view()
|
||||
|
||||
assert status == 200
|
||||
assert body == {"value": 7}
|
||||
|
||||
|
||||
def test_returns_serializes_model_in_tuple_and_honors_status(app):
|
||||
@returns(200, ContractResp)
|
||||
def view():
|
||||
return ContractResp(value=9), 201
|
||||
|
||||
with app.test_request_context("/"):
|
||||
body, status = view()
|
||||
|
||||
assert status == 201
|
||||
assert body == {"value": 9}
|
||||
|
||||
|
||||
def test_returns_passes_through_non_model(app):
|
||||
sentinel = object()
|
||||
|
||||
@returns(200, ContractResp)
|
||||
def view():
|
||||
return sentinel
|
||||
|
||||
with app.test_request_context("/"):
|
||||
result = view()
|
||||
|
||||
assert result is sentinel
|
||||
|
||||
|
||||
def test_returns_serializes_model_in_three_tuple_with_headers(app):
|
||||
"""A (model, status, headers) tuple keeps its trailing status/headers intact."""
|
||||
|
||||
@returns(200, ContractResp)
|
||||
def view():
|
||||
return ContractResp(value=3), 202, {"X-Test": "1"}
|
||||
|
||||
with app.test_request_context("/"):
|
||||
body, status, headers = view()
|
||||
|
||||
assert body == {"value": 3}
|
||||
assert status == 202
|
||||
assert headers == {"X-Test": "1"}
|
||||
|
||||
|
||||
# Swagger metadata (read off __apidoc__) must survive @wraps up through the guard layer.
|
||||
|
||||
|
||||
def test_accepts_returns_emit_apidoc_through_guard_stack():
|
||||
@_guard_like
|
||||
@returns(200, ContractResp)
|
||||
@accepts(query=ContractQuery)
|
||||
def view(*, query):
|
||||
return ContractResp(value=1)
|
||||
|
||||
apidoc = getattr(view, "__apidoc__", {})
|
||||
assert "page" in apidoc.get("params", {}) # from @accepts(query=)
|
||||
assert "200" in apidoc.get("responses", {}) # from @returns (flask_restx keys by str code)
|
||||
|
||||
|
||||
def test_accepts_body_emits_expect_through_guard_stack():
|
||||
@_guard_like
|
||||
@accepts(body=ContractBody)
|
||||
def view(*, body):
|
||||
return body
|
||||
|
||||
apidoc = getattr(view, "__apidoc__", {})
|
||||
assert apidoc.get("expect") # body schema advertised via @openapi_ns.expect
|
||||
@ -11,7 +11,7 @@ from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
from werkzeug.exceptions import NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
@ -233,3 +233,24 @@ class TestOpenApiHumanInputFormPost:
|
||||
submission_end_user_id="eu-7",
|
||||
)
|
||||
assert result == ({}, 200)
|
||||
|
||||
def test_post_rejects_invalid_body_with_422(self, app: Flask, bypass_pipeline):
|
||||
"""Malformed body → 422 via @accepts (was an unmapped pydantic error → 500)."""
|
||||
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
|
||||
|
||||
api = OpenApiWorkflowHumanInputFormApi()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
caller = SimpleNamespace(id="acct-42")
|
||||
|
||||
with app.test_request_context(
|
||||
"/openapi/v1/apps/app-1/form/human_input/tok-1",
|
||||
method="POST",
|
||||
json={"inputs": {"field1": "val"}}, # missing required "action"
|
||||
):
|
||||
with pytest.raises(UnprocessableEntity):
|
||||
api.post.__wrapped__(
|
||||
api,
|
||||
app_id="app-1",
|
||||
form_token="tok-1",
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
)
|
||||
|
||||
@ -29,7 +29,7 @@ import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi._models import MemberInvitePayload, MemberRoleUpdatePayload
|
||||
@ -198,7 +198,7 @@ def test_member_role_route_registered(openapi_app: Flask):
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Payload validation lands at 400
|
||||
# Payload validation lands at 422 (unified via @accepts)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@ -227,18 +227,38 @@ def test_role_payload_rejects_extra_field():
|
||||
MemberRoleUpdatePayload.model_validate({"role": "normal", "extra": "x"})
|
||||
|
||||
|
||||
def test_validate_body_helper_maps_validation_error_to_400(app, monkeypatch):
|
||||
"""`_validate_body` is the centralized 400-mapper for invalid request bodies."""
|
||||
from controllers.openapi.workspaces import _validate_body
|
||||
def test_invite_rejects_invalid_body_with_422(app, bypass_pipeline):
|
||||
"""Invalid invite body → 422 via @accepts (was 400 through _validate_body)."""
|
||||
ws_id = str(uuid.uuid4())
|
||||
acct_id = uuid.uuid4()
|
||||
api = WorkspaceMembersApi()
|
||||
|
||||
with app.test_request_context(
|
||||
"/openapi/v1/workspaces/ws-1/members",
|
||||
f"/openapi/v1/workspaces/{ws_id}/members",
|
||||
method="POST",
|
||||
data=json.dumps({"email": "u@example.com", "role": "owner"}),
|
||||
data=json.dumps({"email": "u@example.com", "role": "owner"}), # owner is not invite-assignable
|
||||
content_type="application/json",
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
_validate_body(MemberInvitePayload)
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(UnprocessableEntity):
|
||||
api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
|
||||
def test_update_role_rejects_invalid_body_with_422(app, bypass_pipeline):
|
||||
"""Invalid role-update body surfaces as 422 through @accepts (was 400)."""
|
||||
ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4())
|
||||
acct_id = uuid.uuid4()
|
||||
api = WorkspaceMemberRoleApi()
|
||||
|
||||
with app.test_request_context(
|
||||
f"/openapi/v1/workspaces/{ws_id}/members/{member_id}/role",
|
||||
method="PUT",
|
||||
data=json.dumps({"role": "owner"}), # closed enum rejects owner
|
||||
content_type="application/json",
|
||||
):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(UnprocessableEntity):
|
||||
api.put.__wrapped__(api, workspace_id=ws_id, member_id=member_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -384,7 +404,7 @@ def test_members_list_paginates_with_query_params(app, bypass_pipeline, monkeypa
|
||||
|
||||
|
||||
def test_members_list_rejects_unknown_query_param(app, bypass_pipeline, monkeypatch):
|
||||
"""Strict (`extra='forbid'`) — typos like `?pg=2` surface as 400."""
|
||||
"""Strict (`extra='forbid'`) — typos like `?pg=2` surface as 422 (unified via @accepts)."""
|
||||
ws_id = str(uuid.uuid4())
|
||||
acct_id = uuid.uuid4()
|
||||
api = WorkspaceMembersApi()
|
||||
@ -395,7 +415,7 @@ def test_members_list_rejects_unknown_query_param(app, bypass_pipeline, monkeypa
|
||||
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members?pg=2"):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(BadRequest):
|
||||
with pytest.raises(UnprocessableEntity):
|
||||
api.get.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import type { AccountContext } from './hosts'
|
||||
import type { TokenStore } from '@/store/token-store'
|
||||
import type { Key, Store } from '@/store/store'
|
||||
import { mkdtemp, rm } from 'node:fs/promises'
|
||||
import { tmpdir } from 'node:os'
|
||||
import { join } from 'node:path'
|
||||
@ -53,20 +53,6 @@ describe('RegistrySchema', () => {
|
||||
})
|
||||
expect(ctx.external_subject?.issuer).toBe('https://issuer')
|
||||
})
|
||||
|
||||
it('strips a stale available_workspaces field from legacy contexts', () => {
|
||||
const raw = {
|
||||
account: { id: 'acct-1', email: 'bob@corp.com', name: 'Bob' },
|
||||
workspace: { id: 'ws-1', name: 'Space', role: 'owner' },
|
||||
available_workspaces: [
|
||||
{ id: 'ws-1', name: 'Space', role: 'owner' },
|
||||
{ id: '00000000-0000-0000-0000-000000000002', name: 'Other', role: 'normal' },
|
||||
],
|
||||
} as unknown as Record<string, unknown>
|
||||
const ctx = AccountContextSchema.parse(raw)
|
||||
expect((ctx as Record<string, unknown>).available_workspaces).toBeUndefined()
|
||||
expect(ctx.workspace?.id).toBe('ws-1')
|
||||
})
|
||||
})
|
||||
|
||||
describe('notLoggedInError', () => {
|
||||
@ -172,12 +158,11 @@ describe('Registry.load / Registry.save', () => {
|
||||
})
|
||||
})
|
||||
|
||||
class MemStore implements TokenStore {
|
||||
readonly entries = new Map<string, string>()
|
||||
private k(host: string, email: string): string { return `${host} ${email}` }
|
||||
read(host: string, email: string): string { return this.entries.get(this.k(host, email)) ?? '' }
|
||||
write(host: string, email: string, bearer: string): void { this.entries.set(this.k(host, email), bearer) }
|
||||
remove(host: string, email: string): void { this.entries.delete(this.k(host, email)) }
|
||||
class MemStore implements Store {
|
||||
readonly entries = new Map<string, unknown>()
|
||||
get<T>(key: Key<T>): T { return (this.entries.get(key.key) as T | undefined) ?? key.default }
|
||||
set<T>(key: Key<T>, value: T): void { this.entries.set(key.key, value) }
|
||||
unset<T>(key: Key<T>): void { this.entries.delete(key.key) }
|
||||
}
|
||||
|
||||
describe('Registry.forget', () => {
|
||||
@ -203,12 +188,12 @@ describe('Registry.forget', () => {
|
||||
reg.setHost('h1')
|
||||
reg.setAccount('a@x')
|
||||
reg.save()
|
||||
store.write('h1', 'a@x', 'dfoa_a')
|
||||
store.set({ key: 'tokens.h1.a@x', default: '' }, 'dfoa_a')
|
||||
|
||||
const active = reg.resolveActive()!
|
||||
reg.forget(active, store)
|
||||
|
||||
expect(store.read('h1', 'a@x')).toBe('')
|
||||
expect(store.get({ key: 'tokens.h1.a@x', default: '' })).toBe('')
|
||||
const after = Registry.load()
|
||||
expect(after?.hosts.h1?.accounts['a@x']).toBeUndefined()
|
||||
expect(after?.hosts.h1?.accounts['b@x']).toBeDefined()
|
||||
|
||||
@ -1,14 +1,11 @@
|
||||
import type { StorageMode } from '@/store/store'
|
||||
import type { TokenStore } from '@/store/token-store'
|
||||
import type { Store } from '@/store/store'
|
||||
import { z } from 'zod'
|
||||
import { BaseError } from '@/errors/base'
|
||||
import { ErrorCode } from '@/errors/codes'
|
||||
import { getHostStore } from '@/store/manager'
|
||||
import { STORAGE_MODES } from '@/store/store'
|
||||
import { getHostStore, tokenKey } from '@/store/manager'
|
||||
|
||||
const StorageModeSchema = z.enum(STORAGE_MODES)
|
||||
|
||||
export type { StorageMode } from '@/store/store'
|
||||
const StorageModeSchema = z.enum(['keychain', 'file'])
|
||||
export type StorageMode = z.infer<typeof StorageModeSchema>
|
||||
|
||||
export const AccountSchema = z.object({
|
||||
id: z.string().optional(),
|
||||
@ -33,6 +30,7 @@ export type ExternalSubject = z.infer<typeof ExternalSubjectSchema>
|
||||
export const AccountContextSchema = z.object({
|
||||
account: AccountSchema,
|
||||
workspace: WorkspaceSchema.optional(),
|
||||
available_workspaces: z.array(WorkspaceSchema).optional(),
|
||||
token_id: z.string().optional(),
|
||||
token_expires_at: z.string().optional(),
|
||||
external_subject: ExternalSubjectSchema.optional(),
|
||||
@ -165,9 +163,9 @@ export class Registry {
|
||||
|
||||
// Teardown for "this credential is gone": drop the token, drop the context
|
||||
// (unsets pointers when active), persist. Logout + self-revoke share it.
|
||||
forget(active: ActiveContext, store: TokenStore): void {
|
||||
forget(active: ActiveContext, store: Store): void {
|
||||
try {
|
||||
store.remove(active.host, active.email)
|
||||
store.unset(tokenKey(active.host, active.email))
|
||||
}
|
||||
catch { /* best-effort */ }
|
||||
this.remove(active.host, active.email)
|
||||
|
||||
@ -2,7 +2,7 @@ import type { ActiveContext } from '@/auth/hosts'
|
||||
import type { AppInfoCache } from '@/cache/app-info'
|
||||
import type { Command } from '@/framework/command'
|
||||
import type { HttpClient } from '@/http/types'
|
||||
import type { TokenStore } from '@/store/token-store'
|
||||
import type { Store } from '@/store/store'
|
||||
import type { IOStreams } from '@/sys/io/streams'
|
||||
import { META_PROBE_TIMEOUT_MS, MetaClient } from '@/api/meta'
|
||||
import { notLoggedInError, Registry } from '@/auth/hosts'
|
||||
@ -11,7 +11,7 @@ import { loadNudgeStore } from '@/cache/nudge-store'
|
||||
import { getEnv } from '@/env/registry'
|
||||
import { formatErrorForCli } from '@/errors/format'
|
||||
import { createHttpClient } from '@/http/client'
|
||||
import { getTokenStore } from '@/store/manager'
|
||||
import { getTokenStore, tokenKey } from '@/store/manager'
|
||||
import { realStreams } from '@/sys/io/streams'
|
||||
import { hostWithScheme, openAPIBase } from '@/util/host'
|
||||
import { versionInfo } from '@/version/info'
|
||||
@ -21,7 +21,7 @@ import { resolveRetryAttempts } from './global-flags.js'
|
||||
export type AuthedContext = {
|
||||
readonly reg: Registry
|
||||
readonly active: ActiveContext
|
||||
readonly store: TokenStore
|
||||
readonly store: Store
|
||||
readonly http: HttpClient
|
||||
readonly host: string
|
||||
readonly io: IOStreams
|
||||
@ -44,8 +44,8 @@ export async function buildAuthedContext(
|
||||
if (active === undefined)
|
||||
fail(cmd, opts, io)
|
||||
|
||||
const store = getTokenStore(reg.token_storage)
|
||||
const bearer = store.read(active.host, active.email)
|
||||
const { store } = getTokenStore()
|
||||
const bearer = store.get(tokenKey(active.host, active.email))
|
||||
if (bearer === '')
|
||||
fail(cmd, opts, io)
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ import type { SessionListResponse, SessionRow } from '@dify/contracts/api/openap
|
||||
import type { DifyMock } from '@test/fixtures/dify-mock/server'
|
||||
import type { AccountSessionsClient } from '@/api/account-sessions'
|
||||
import type { ActiveContext } from '@/auth/hosts'
|
||||
import type { TokenStore } from '@/store/token-store'
|
||||
import type { Key, Store } from '@/store/store'
|
||||
import { mkdtemp, rm } from 'node:fs/promises'
|
||||
import { tmpdir } from 'node:os'
|
||||
import { join } from 'node:path'
|
||||
@ -11,25 +11,22 @@ import { testHttpClient } from '@test/fixtures/http-client'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { Registry } from '@/auth/hosts'
|
||||
import { ENV_CONFIG_DIR } from '@/store/dir'
|
||||
import { tokenKey } from '@/store/manager'
|
||||
import { bufferStreams } from '@/sys/io/streams'
|
||||
import { listAllSessions, runDevicesList, runDevicesRevoke } from './devices.js'
|
||||
|
||||
class MemStore implements TokenStore {
|
||||
readonly entries = new Map<string, string>()
|
||||
private k(host: string, email: string): string {
|
||||
return `${host} ${email}`
|
||||
class MemStore implements Store {
|
||||
readonly entries = new Map<string, unknown>()
|
||||
get<T>(key: Key<T>): T {
|
||||
return (this.entries.get(key.key) as T | undefined) ?? key.default
|
||||
}
|
||||
|
||||
read(host: string, email: string): string {
|
||||
return this.entries.get(this.k(host, email)) ?? ''
|
||||
set<T>(key: Key<T>, value: T): void {
|
||||
this.entries.set(key.key, value)
|
||||
}
|
||||
|
||||
write(host: string, email: string, bearer: string): void {
|
||||
this.entries.set(this.k(host, email), bearer)
|
||||
}
|
||||
|
||||
remove(host: string, email: string): void {
|
||||
this.entries.delete(this.k(host, email))
|
||||
unset<T>(key: Key<T>): void {
|
||||
this.entries.delete(key.key)
|
||||
}
|
||||
}
|
||||
|
||||
@ -38,6 +35,10 @@ function buildRegistry(host: string, email: string, tokenId: string): { reg: Reg
|
||||
reg.upsert(host, email, {
|
||||
account: { id: 'acct-1', email, name: 'Test Tester' },
|
||||
workspace: { id: 'ws-1', name: 'Default', role: 'owner' },
|
||||
available_workspaces: [
|
||||
{ id: 'ws-1', name: 'Default', role: 'owner' },
|
||||
{ id: 'ws-2', name: 'Other', role: 'normal' },
|
||||
],
|
||||
token_id: tokenId,
|
||||
})
|
||||
reg.setHost(host)
|
||||
@ -102,7 +103,7 @@ describe('runDevicesRevoke', () => {
|
||||
const io = bufferStreams()
|
||||
const store = new MemStore()
|
||||
const { reg, active } = buildRegistry(mock.url, 'tester@dify.ai', 'tok-1')
|
||||
store.write(mock.url, 'tester@dify.ai', 'dfoa_test')
|
||||
store.set(tokenKey(mock.url, 'tester@dify.ai'), 'dfoa_test')
|
||||
reg.save()
|
||||
const http = testHttpClient(mock.url, 'dfoa_test')
|
||||
|
||||
@ -167,7 +168,7 @@ describe('runDevicesRevoke', () => {
|
||||
const io = bufferStreams()
|
||||
const store = new MemStore()
|
||||
const { reg, active } = buildRegistry(mock.url, 'tester@dify.ai', 'tok-1')
|
||||
store.write(mock.url, 'tester@dify.ai', 'dfoa_test')
|
||||
store.set(tokenKey(mock.url, 'tester@dify.ai'), 'dfoa_test')
|
||||
reg.save()
|
||||
const http = testHttpClient(mock.url, 'dfoa_test')
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import type { SessionRow } from '@dify/contracts/api/openapi/types.gen'
|
||||
import type { ActiveContext, Registry } from '@/auth/hosts'
|
||||
import type { HttpClient } from '@/http/types'
|
||||
import type { TokenStore } from '@/store/token-store'
|
||||
import type { Store } from '@/store/store'
|
||||
import type { IOStreams } from '@/sys/io/streams'
|
||||
import { AccountSessionsClient } from '@/api/account-sessions'
|
||||
import { BaseError } from '@/errors/base'
|
||||
@ -71,7 +71,7 @@ export type DevicesRevokeOptions = {
|
||||
readonly io: IOStreams
|
||||
readonly reg: Registry
|
||||
readonly active: ActiveContext
|
||||
readonly store: TokenStore
|
||||
readonly store: Store
|
||||
readonly http: HttpClient
|
||||
readonly target?: string
|
||||
readonly all: boolean
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import type { DifyMock } from '@test/fixtures/dify-mock/server'
|
||||
import type { Clock } from './device-flow.js'
|
||||
import type { TokenStore } from '@/store/token-store'
|
||||
import type { Key, Store } from '@/store/store'
|
||||
import { mkdtemp, readFile, rm } from 'node:fs/promises'
|
||||
import { tmpdir } from 'node:os'
|
||||
import { join } from 'node:path'
|
||||
@ -10,6 +10,7 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'
|
||||
import { DeviceFlowApi } from '@/api/oauth-device'
|
||||
import { createHttpClient } from '@/http/client'
|
||||
import { ENV_CONFIG_DIR } from '@/store/dir'
|
||||
import { tokenKey } from '@/store/manager'
|
||||
import { bufferStreams } from '@/sys/io/streams'
|
||||
import { openAPIBase } from '@/util/host'
|
||||
import { runLogin } from './login.js'
|
||||
@ -21,22 +22,18 @@ const noopClock: Clock = {
|
||||
|
||||
const noopBrowser = async (): Promise<void> => { /* skip OS open */ }
|
||||
|
||||
class MemStore implements TokenStore {
|
||||
readonly entries = new Map<string, string>()
|
||||
private k(host: string, email: string): string {
|
||||
return `${host} ${email}`
|
||||
class MemStore implements Store {
|
||||
readonly entries = new Map<string, unknown>()
|
||||
get<T>(key: Key<T>): T {
|
||||
return (this.entries.get(key.key) as T | undefined) ?? key.default
|
||||
}
|
||||
|
||||
read(host: string, email: string): string {
|
||||
return this.entries.get(this.k(host, email)) ?? ''
|
||||
set<T>(key: Key<T>, value: T): void {
|
||||
this.entries.set(key.key, value)
|
||||
}
|
||||
|
||||
write(host: string, email: string, bearer: string): void {
|
||||
this.entries.set(this.k(host, email), bearer)
|
||||
}
|
||||
|
||||
remove(host: string, email: string): void {
|
||||
this.entries.delete(this.k(host, email))
|
||||
unset<T>(key: Key<T>): void {
|
||||
this.entries.delete(key.key)
|
||||
}
|
||||
}
|
||||
|
||||
@ -78,7 +75,8 @@ describe('runLogin', () => {
|
||||
const active = reg.resolveActive()
|
||||
expect(active?.ctx.account.email).toBe('tester@dify.ai')
|
||||
expect(active?.ctx.workspace?.id).toBe('550e8400-e29b-41d4-a716-446655440000')
|
||||
expect(store.read(active!.host, 'tester@dify.ai')).toBe('dfoa_test')
|
||||
expect(active?.ctx.available_workspaces).toHaveLength(2)
|
||||
expect(store.get(tokenKey(active!.host, 'tester@dify.ai'))).toBe('dfoa_test')
|
||||
|
||||
const hostsRaw = await readFile(join(configDir, 'hosts.yml'), 'utf8')
|
||||
expect(hostsRaw).toContain('current_host:')
|
||||
@ -111,7 +109,7 @@ describe('runLogin', () => {
|
||||
expect(active?.ctx.external_subject?.email).toBe('sso@dify.ai')
|
||||
expect(active?.ctx.external_subject?.issuer).toBe('https://issuer.example')
|
||||
expect(active?.ctx.account.email).toBe('')
|
||||
expect(store.read(active!.host, 'sso@dify.ai')).toBe('dfoe_test')
|
||||
expect(store.get(tokenKey(active!.host, 'sso@dify.ai'))).toBe('dfoe_test')
|
||||
expect(io.outBuf()).toContain('external SSO')
|
||||
expect(io.outBuf()).toContain('sso@dify.ai')
|
||||
})
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
import type { Clock } from './device-flow.js'
|
||||
import type { CodeResponse, PollSuccess } from '@/api/oauth-device'
|
||||
import type { AccountContext } from '@/auth/hosts'
|
||||
import type { StorageMode } from '@/store/store'
|
||||
import type { TokenStore } from '@/store/token-store'
|
||||
import type { AccountContext, Workspace } from '@/auth/hosts'
|
||||
import type { StorageMode, Store } from '@/store/store'
|
||||
import type { ParseResult } from '@/sys/io/prompt'
|
||||
import type { IOStreams } from '@/sys/io/streams'
|
||||
import type { BrowserEnv, BrowserOpener } from '@/util/browser'
|
||||
@ -12,7 +11,7 @@ import { Registry } from '@/auth/hosts'
|
||||
import { BaseError, isBaseError } from '@/errors/base'
|
||||
import { ErrorCode } from '@/errors/codes'
|
||||
import { createHttpClient } from '@/http/client'
|
||||
import { detectTokenStore } from '@/store/manager'
|
||||
import { getTokenStore, tokenKey } from '@/store/manager'
|
||||
import { colorEnabled, colorScheme } from '@/sys/io/color'
|
||||
import { promptText } from '@/sys/io/prompt'
|
||||
import { startSpinner } from '@/sys/io/spinner'
|
||||
@ -26,7 +25,7 @@ export type LoginOptions = {
|
||||
readonly noBrowser?: boolean
|
||||
readonly insecure?: boolean
|
||||
readonly deviceLabel?: string
|
||||
readonly store?: { readonly store: TokenStore, readonly mode: StorageMode }
|
||||
readonly store?: { readonly store: Store, readonly mode: StorageMode }
|
||||
readonly api?: DeviceFlowApi
|
||||
readonly browserEnv?: BrowserEnv
|
||||
readonly browserOpener?: BrowserOpener
|
||||
@ -70,12 +69,12 @@ export async function runLogin(opts: LoginOptions): Promise<Registry> {
|
||||
spinner.stop()
|
||||
}
|
||||
|
||||
const storeBundle = opts.store ?? detectTokenStore()
|
||||
const storeBundle = opts.store ?? getTokenStore()
|
||||
const display = bareHost(host)
|
||||
const email = accountEmail(success)
|
||||
const ctx = contextFromSuccess(success)
|
||||
|
||||
storeBundle.store.write(display, email, success.token)
|
||||
storeBundle.store.set(tokenKey(display, email), success.token)
|
||||
|
||||
const reg = Registry.load()
|
||||
reg.token_storage = storeBundle.mode
|
||||
@ -188,6 +187,9 @@ function contextFromSuccess(s: PollSuccess): AccountContext {
|
||||
const def = findDefaultWorkspace(s)
|
||||
if (def !== undefined)
|
||||
ctx.workspace = def
|
||||
if (s.workspaces !== undefined && s.workspaces.length > 0) {
|
||||
ctx.available_workspaces = s.workspaces.map<Workspace>(w => ({ id: w.id, name: w.name, role: w.role }))
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ import type { HttpClient } from '@/http/types'
|
||||
import { Registry } from '@/auth/hosts'
|
||||
import { DifyCommand } from '@/commands/_shared/dify-command'
|
||||
import { createHttpClient } from '@/http/client'
|
||||
import { getTokenStore } from '@/store/manager'
|
||||
import { getTokenStore, tokenKey } from '@/store/manager'
|
||||
import { runWithSpinner } from '@/sys/io/spinner'
|
||||
import { realStreams } from '@/sys/io/streams'
|
||||
import { hostWithScheme, openAPIBase } from '@/util/host'
|
||||
@ -26,11 +26,7 @@ export default class Logout extends DifyCommand {
|
||||
|
||||
let http: HttpClient | undefined
|
||||
if (active !== undefined) {
|
||||
let bearer = ''
|
||||
try {
|
||||
bearer = getTokenStore(reg.token_storage).read(active.host, active.email)
|
||||
}
|
||||
catch { /* keyring locked — skip remote revocation, local cleanup still runs */ }
|
||||
const bearer = getTokenStore().store.get(tokenKey(active.host, active.email))
|
||||
if (bearer !== '') {
|
||||
http = createHttpClient({ baseURL: openAPIBase(hostWithScheme(active.host, active.scheme)), bearer, retryAttempts: 0 })
|
||||
}
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { TokenStore } from '@/store/token-store'
|
||||
import type { Key, Store } from '@/store/store'
|
||||
import { mkdtemp, readFile, rm } from 'node:fs/promises'
|
||||
import { tmpdir } from 'node:os'
|
||||
import { join } from 'node:path'
|
||||
@ -8,12 +8,11 @@ import { ENV_CONFIG_DIR } from '@/store/dir'
|
||||
import { bufferStreams } from '@/sys/io/streams'
|
||||
import { runLogout } from './logout.js'
|
||||
|
||||
class MemStore implements TokenStore {
|
||||
readonly entries = new Map<string, string>()
|
||||
private k(host: string, email: string): string { return `${host} ${email}` }
|
||||
read(host: string, email: string): string { return this.entries.get(this.k(host, email)) ?? '' }
|
||||
write(host: string, email: string, bearer: string): void { this.entries.set(this.k(host, email), bearer) }
|
||||
remove(host: string, email: string): void { this.entries.delete(this.k(host, email)) }
|
||||
class MemStore implements Store {
|
||||
readonly entries = new Map<string, unknown>()
|
||||
get<T>(key: Key<T>): T { return (this.entries.get(key.key) as T | undefined) ?? key.default }
|
||||
set<T>(key: Key<T>, value: T): void { this.entries.set(key.key, value) }
|
||||
unset<T>(key: Key<T>): void { this.entries.delete(key.key) }
|
||||
}
|
||||
|
||||
describe('runLogout', () => {
|
||||
@ -38,8 +37,8 @@ describe('runLogout', () => {
|
||||
reg.setHost('h1')
|
||||
reg.setAccount('a@x')
|
||||
reg.save()
|
||||
store.write('h1', 'a@x', 'dfoa_a')
|
||||
store.write('h1', 'b@x', 'dfoa_b')
|
||||
store.set({ key: 'tokens.h1.a@x', default: '' }, 'dfoa_a')
|
||||
store.set({ key: 'tokens.h1.b@x', default: '' }, 'dfoa_b')
|
||||
}
|
||||
|
||||
it('removes only the active context, keeps others, unsets pointers, file survives', async () => {
|
||||
@ -50,23 +49,12 @@ describe('runLogout', () => {
|
||||
expect(after?.hosts.h1?.accounts['a@x']).toBeUndefined()
|
||||
expect(after?.hosts.h1?.accounts['b@x']).toBeDefined()
|
||||
expect(after?.current_host).toBeUndefined()
|
||||
expect(store.read('h1', 'a@x')).toBe('')
|
||||
expect(store.read('h1', 'b@x')).toBe('dfoa_b')
|
||||
expect(store.get({ key: 'tokens.h1.a@x', default: '' })).toBe('')
|
||||
expect(store.get({ key: 'tokens.h1.b@x', default: '' })).toBe('dfoa_b')
|
||||
const raw = await readFile(join(dir, 'hosts.yml'), 'utf8')
|
||||
expect(raw).toContain('b@x')
|
||||
})
|
||||
|
||||
it('clears local credentials even when the store.read throws (e.g. keyring locked)', async () => {
|
||||
const store = new MemStore()
|
||||
seed(store)
|
||||
store.read = () => {
|
||||
throw new Error('keyring locked')
|
||||
}
|
||||
await runLogout({ io: bufferStreams(), reg: Registry.load(), store })
|
||||
const after = Registry.load()
|
||||
expect(after?.hosts.h1?.accounts['a@x']).toBeUndefined()
|
||||
})
|
||||
|
||||
it('throws NotLoggedIn when no active context', async () => {
|
||||
Registry.empty('file').save()
|
||||
await expect(runLogout({ io: bufferStreams(), reg: Registry.load(), store: new MemStore() }))
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
import type { Registry } from '@/auth/hosts'
|
||||
import type { HttpClient } from '@/http/types'
|
||||
import type { TokenStore } from '@/store/token-store'
|
||||
import type { Store } from '@/store/store'
|
||||
import type { IOStreams } from '@/sys/io/streams'
|
||||
import { AccountSessionsClient } from '@/api/account-sessions'
|
||||
import { getTokenStore } from '@/store/manager'
|
||||
import { getTokenStore, tokenKey } from '@/store/manager'
|
||||
import { colorEnabled, colorScheme } from '@/sys/io/color'
|
||||
|
||||
export type LogoutOptions = {
|
||||
@ -11,7 +11,7 @@ export type LogoutOptions = {
|
||||
readonly reg: Registry
|
||||
readonly http?: HttpClient
|
||||
/** Optional override for tests; production resolves via `getTokenStore`. */
|
||||
readonly store?: TokenStore
|
||||
readonly store?: Store
|
||||
}
|
||||
|
||||
const REVOCABLE_PREFIXES = ['dfoa_', 'dfoe_'] as const
|
||||
@ -21,12 +21,8 @@ export async function runLogout(opts: LogoutOptions): Promise<void> {
|
||||
const reg = opts.reg
|
||||
const active = reg.requireActive()
|
||||
|
||||
const store = opts.store ?? getTokenStore(reg.token_storage)
|
||||
let bearer = ''
|
||||
try {
|
||||
bearer = store.read(active.host, active.email)
|
||||
}
|
||||
catch { /* keyring locked — skip remote revocation, local cleanup still runs */ }
|
||||
const store = opts.store ?? getTokenStore().store
|
||||
const bearer = store.get(tokenKey(active.host, active.email))
|
||||
|
||||
let revokeWarning = ''
|
||||
if (bearer !== '' && revokeAllowed(bearer) && opts.http !== undefined) {
|
||||
|
||||
@ -11,6 +11,7 @@ function active(): ActiveContext {
|
||||
ctx: {
|
||||
account: { id: 'acct-1', email: 'inviter@example.com', name: 'Inviter' },
|
||||
workspace: { id: '550e8400-e29b-41d4-a716-446655440000', name: 'Default', role: 'owner' },
|
||||
available_workspaces: [{ id: '550e8400-e29b-41d4-a716-446655440000', name: 'Default', role: 'owner' }],
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@ -11,6 +11,7 @@ function active(): ActiveContext {
|
||||
ctx: {
|
||||
account: { id: 'acct-1', email: 'me@example.com', name: 'Me' },
|
||||
workspace: { id: '550e8400-e29b-41d4-a716-446655440000', name: 'Default', role: 'owner' },
|
||||
available_workspaces: [{ id: '550e8400-e29b-41d4-a716-446655440000', name: 'Default', role: 'owner' }],
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@ -19,6 +19,10 @@ function active(): ActiveContext {
|
||||
ctx: {
|
||||
account: { id: 'acct-1', email: 't@d.ai', name: 'T' },
|
||||
workspace: { id: 'ws-1', name: 'Default', role: 'owner' },
|
||||
available_workspaces: [
|
||||
{ id: 'ws-1', name: 'Default', role: 'owner' },
|
||||
{ id: 'ws-2', name: 'Other', role: 'normal' },
|
||||
],
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -67,7 +71,7 @@ describe('runDescribeApp', () => {
|
||||
})
|
||||
|
||||
it('text: agent app shows Agent: true', async () => {
|
||||
const out = await render({ appId: 'app-4', workspace: '00000000-0000-0000-0000-000000000002' })
|
||||
const out = await render({ appId: 'app-4', workspace: 'ws-2' })
|
||||
expect(out).toContain('Agent:')
|
||||
expect(out).toContain('true')
|
||||
})
|
||||
|
||||
@ -13,6 +13,10 @@ const baseActive: ActiveContext = {
|
||||
ctx: {
|
||||
account: { id: 'acct-1', email: 'tester@dify.ai', name: 'Test Tester' },
|
||||
workspace: { id: '550e8400-e29b-41d4-a716-446655440000', name: 'Default', role: 'owner' },
|
||||
available_workspaces: [
|
||||
{ id: '550e8400-e29b-41d4-a716-446655440000', name: 'Default', role: 'owner' },
|
||||
{ id: '550e8400-e29b-41d4-a716-446655440001', name: 'Other', role: 'normal' },
|
||||
],
|
||||
},
|
||||
scheme: 'http',
|
||||
}
|
||||
|
||||
@ -114,7 +114,14 @@ function describeToEnvelope(desc: AppDescribeResponse, wsId: string, wsName: str
|
||||
function workspaceNameForId(active: ActiveContext, id: string): string {
|
||||
if (id === '')
|
||||
return ''
|
||||
return active.ctx.workspace?.id === id ? active.ctx.workspace.name : ''
|
||||
const ctx = active.ctx
|
||||
if (ctx.workspace?.id === id)
|
||||
return ctx.workspace.name
|
||||
for (const w of ctx.available_workspaces ?? []) {
|
||||
if (w.id === id)
|
||||
return w.name
|
||||
}
|
||||
return ''
|
||||
}
|
||||
|
||||
async function runAllWorkspaces(
|
||||
|
||||
@ -12,6 +12,7 @@ function active(): ActiveContext {
|
||||
ctx: {
|
||||
account: { id: 'acct-1', email: 'me@example.com', name: 'Me' },
|
||||
workspace: { id: '550e8400-e29b-41d4-a716-446655440000', name: 'Default', role: 'owner' },
|
||||
available_workspaces: [{ id: '550e8400-e29b-41d4-a716-446655440000', name: 'Default', role: 'owner' }],
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@ -6,7 +6,7 @@ function env(): WorkspaceListResponse {
|
||||
return {
|
||||
workspaces: [
|
||||
{ id: 'ws-1', name: 'Default', role: 'owner', status: 'normal', current: true },
|
||||
{ id: '00000000-0000-0000-0000-000000000002', name: 'Other', role: 'normal', status: 'normal', current: false },
|
||||
{ id: 'ws-2', name: 'Other', role: 'normal', status: 'normal', current: false },
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
@ -13,6 +13,10 @@ const baseActive: ActiveContext = {
|
||||
ctx: {
|
||||
account: { id: 'acct-1', email: 'tester@dify.ai', name: 'Test Tester' },
|
||||
workspace: { id: '550e8400-e29b-41d4-a716-446655440000', name: 'Default', role: 'owner' },
|
||||
available_workspaces: [
|
||||
{ id: '550e8400-e29b-41d4-a716-446655440000', name: 'Default', role: 'owner' },
|
||||
{ id: '550e8400-e29b-41d4-a716-446655440001', name: 'Other', role: 'normal' },
|
||||
],
|
||||
},
|
||||
scheme: 'http',
|
||||
}
|
||||
|
||||
@ -20,6 +20,10 @@ function active(): ActiveContext {
|
||||
ctx: {
|
||||
account: { id: 'acct-1', email: 't@d.ai', name: 'T' },
|
||||
workspace: { id: 'ws-1', name: 'Default', role: 'owner' },
|
||||
available_workspaces: [
|
||||
{ id: 'ws-1', name: 'Default', role: 'owner' },
|
||||
{ id: 'ws-2', name: 'Other', role: 'normal' },
|
||||
],
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -135,7 +139,7 @@ describe('runApp', () => {
|
||||
const io = bufferStreams()
|
||||
const cache = await loadAppInfoCache({ store: getCache(CACHE_APP_INFO) })
|
||||
await runApp(
|
||||
{ appId: 'app-4', workspace: '00000000-0000-0000-0000-000000000002', message: 'do research' },
|
||||
{ appId: 'app-4', workspace: 'ws-2', message: 'do research' },
|
||||
{ active: active(), http: testHttpClient(mock.url, 'dfoa_test'), host: mock.url, io, cache },
|
||||
)
|
||||
expect(io.outBuf()).toContain('do research')
|
||||
@ -146,7 +150,7 @@ describe('runApp', () => {
|
||||
const io = bufferStreams()
|
||||
const cache = await loadAppInfoCache({ store: getCache(CACHE_APP_INFO) })
|
||||
await runApp(
|
||||
{ appId: 'app-4', workspace: '00000000-0000-0000-0000-000000000002', message: 'go', stream: true },
|
||||
{ appId: 'app-4', workspace: 'ws-2', message: 'go', stream: true },
|
||||
{ active: active(), http: testHttpClient(mock.url, 'dfoa_test'), host: mock.url, io, cache },
|
||||
)
|
||||
expect(io.outBuf()).toContain('go')
|
||||
|
||||
@ -11,6 +11,7 @@ function active(): ActiveContext {
|
||||
ctx: {
|
||||
account: { id: 'acct-1', email: 'me@example.com', name: 'Me' },
|
||||
workspace: { id: '550e8400-e29b-41d4-a716-446655440000', name: 'Default', role: 'owner' },
|
||||
available_workspaces: [{ id: '550e8400-e29b-41d4-a716-446655440000', name: 'Default', role: 'owner' }],
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { TokenStore } from '@/store/token-store'
|
||||
import type { Key, Store } from '@/store/store'
|
||||
import { mkdtemp, rm } from 'node:fs/promises'
|
||||
import { tmpdir } from 'node:os'
|
||||
import { join } from 'node:path'
|
||||
@ -8,13 +8,12 @@ import { ENV_CONFIG_DIR } from '@/store/dir'
|
||||
import { bufferStreams } from '@/sys/io/streams'
|
||||
import { runUseAccount } from './use-account'
|
||||
|
||||
function memStore(seed: Record<string, string>): TokenStore {
|
||||
const k = (host: string, email: string): string => `${host} ${email}`
|
||||
const m = new Map<string, string>(Object.entries(seed))
|
||||
function memStore(seed: Record<string, string>): Store {
|
||||
const m = new Map<string, unknown>(Object.entries(seed))
|
||||
return {
|
||||
read(host: string, email: string): string { return m.get(k(host, email)) ?? '' },
|
||||
write(host: string, email: string, bearer: string): void { m.set(k(host, email), bearer) },
|
||||
remove(host: string, email: string): void { m.delete(k(host, email)) },
|
||||
get<T>(k: Key<T>): T { return (m.get(k.key) as T | undefined) ?? k.default },
|
||||
set<T>(k: Key<T>, v: T): void { m.set(k.key, v) },
|
||||
unset<T>(k: Key<T>): void { m.delete(k.key) },
|
||||
}
|
||||
}
|
||||
|
||||
@ -40,7 +39,7 @@ describe('runUseAccount', () => {
|
||||
})
|
||||
|
||||
it('switches current_account when email valid + token present', async () => {
|
||||
await runUseAccount({ io: bufferStreams(), email: 'b@x', store: memStore({ 'h1 b@x': 'dfoa_b' }) })
|
||||
await runUseAccount({ io: bufferStreams(), email: 'b@x', store: memStore({ 'tokens.h1.b@x': 'dfoa_b' }) })
|
||||
expect(Registry.load().hosts.h1?.current_account).toBe('b@x')
|
||||
})
|
||||
|
||||
@ -51,7 +50,7 @@ describe('runUseAccount', () => {
|
||||
})
|
||||
|
||||
it('errors when the email is unknown on the current host', async () => {
|
||||
await expect(runUseAccount({ io: bufferStreams(), email: 'z@x', store: memStore({ 'h1 z@x': 'x' }) }))
|
||||
await expect(runUseAccount({ io: bufferStreams(), email: 'z@x', store: memStore({ 'tokens.h1.z@x': 'x' }) }))
|
||||
.rejects
|
||||
.toThrow(/unknown account|no account/i)
|
||||
})
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
import type { HostEntry } from '@/auth/hosts'
|
||||
import type { TokenStore } from '@/store/token-store'
|
||||
import type { Store } from '@/store/store'
|
||||
import type { IOStreams } from '@/sys/io/streams'
|
||||
import { notLoggedInError, Registry } from '@/auth/hosts'
|
||||
import { BaseError } from '@/errors/base'
|
||||
import { ErrorCode } from '@/errors/codes'
|
||||
import { getTokenStore } from '@/store/manager'
|
||||
import { getTokenStore, tokenKey } from '@/store/manager'
|
||||
import { colorEnabled, colorScheme } from '@/sys/io/color'
|
||||
import { selectFromList } from '@/sys/io/select'
|
||||
|
||||
@ -12,7 +12,7 @@ export type UseAccountOptions = {
|
||||
readonly io: IOStreams
|
||||
readonly email: string | undefined
|
||||
/** Optional override for tests; production resolves via `getTokenStore`. */
|
||||
readonly store?: TokenStore
|
||||
readonly store?: Store
|
||||
}
|
||||
|
||||
type AccountChoice = { email: string, name: string, sso: boolean, active: boolean }
|
||||
@ -38,8 +38,8 @@ export async function runUseAccount(opts: UseAccountOptions): Promise<void> {
|
||||
})
|
||||
}
|
||||
|
||||
const store = opts.store ?? getTokenStore(reg.token_storage)
|
||||
if (store.read(host, target) === '') {
|
||||
const store = opts.store ?? getTokenStore().store
|
||||
if (store.get(tokenKey(host, target)) === '') {
|
||||
throw new BaseError({
|
||||
code: ErrorCode.NotLoggedIn,
|
||||
message: `no credential stored for ${target} on ${host}`,
|
||||
|
||||
@ -5,17 +5,16 @@ import { Args } from '@/framework/flags'
|
||||
import { runUseWorkspace } from './use'
|
||||
|
||||
export default class UseWorkspace extends DifyCommand {
|
||||
static override description = 'Switch the active workspace on the server (omit the id to pick interactively)'
|
||||
static override description = 'Switch the active workspace on the server and refresh hosts.yml'
|
||||
|
||||
static override effect: CommandEffect = 'write'
|
||||
|
||||
static override examples = [
|
||||
'<%= config.bin %> use workspace ws-abc123',
|
||||
'<%= config.bin %> use workspace',
|
||||
]
|
||||
|
||||
static override args = {
|
||||
workspaceId: Args.string({ description: 'workspace id to switch to (omit to pick interactively)', required: false }),
|
||||
workspaceId: Args.string({ description: 'workspace id to switch to', required: true }),
|
||||
}
|
||||
|
||||
static override flags = {
|
||||
|
||||
@ -10,21 +10,18 @@ import { join } from 'node:path'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { Registry } from '@/auth/hosts'
|
||||
import { ENV_CONFIG_DIR } from '@/store/dir'
|
||||
import { selectFromList } from '@/sys/io/select'
|
||||
import { bufferStreams } from '@/sys/io/streams'
|
||||
import { runUseWorkspace } from './use.js'
|
||||
|
||||
vi.mock('@/sys/io/select', () => ({
|
||||
selectFromList: vi.fn(),
|
||||
}))
|
||||
|
||||
const selectFromListMock = vi.mocked(selectFromList)
|
||||
|
||||
function makeRegistry(): Registry {
|
||||
const reg = Registry.empty('file')
|
||||
reg.upsert('cloud.dify.ai', 'tester@dify.ai', {
|
||||
account: { id: 'acct-1', email: 'tester@dify.ai', name: 'Tester' },
|
||||
workspace: { id: 'ws-1', name: 'Default', role: 'owner' },
|
||||
available_workspaces: [
|
||||
{ id: 'ws-1', name: 'Default', role: 'owner' },
|
||||
{ id: 'ws-2', name: 'Stale Name', role: 'normal' },
|
||||
],
|
||||
})
|
||||
reg.setHost('cloud.dify.ai')
|
||||
reg.setAccount('tester@dify.ai')
|
||||
@ -38,28 +35,23 @@ function makeActive(reg: Registry): ActiveContext {
|
||||
return active
|
||||
}
|
||||
|
||||
function makeDetail(over: Partial<WorkspaceDetailResponse> = {}): WorkspaceDetailResponse {
|
||||
return {
|
||||
id: '00000000-0000-0000-0000-000000000002',
|
||||
name: 'Two',
|
||||
role: 'owner',
|
||||
status: 'normal',
|
||||
current: true,
|
||||
created_at: '2026-05-18T00:00:00Z',
|
||||
...over,
|
||||
}
|
||||
}
|
||||
|
||||
function fakeClient(opts: {
|
||||
switch?: () => Promise<WorkspaceDetailResponse>
|
||||
list?: () => Promise<WorkspaceListResponse>
|
||||
}) {
|
||||
return {
|
||||
switch: vi.fn(opts.switch ?? (() => Promise.resolve(makeDetail()))),
|
||||
switch: vi.fn(opts.switch ?? (() => Promise.resolve({
|
||||
id: 'ws-2',
|
||||
name: 'Switched',
|
||||
role: 'normal',
|
||||
status: 'normal',
|
||||
current: true,
|
||||
created_at: '2026-05-18T00:00:00Z',
|
||||
}))),
|
||||
list: vi.fn(opts.list ?? (() => Promise.resolve({
|
||||
workspaces: [
|
||||
{ id: 'ws-1', name: 'Default', role: 'owner', status: 'normal', current: true },
|
||||
{ id: '00000000-0000-0000-0000-000000000002', name: 'Two', role: 'owner', status: 'normal', current: false },
|
||||
{ id: 'ws-1', name: 'Default', role: 'owner', status: 'normal', current: false },
|
||||
{ id: 'ws-2', name: 'Switched', role: 'normal', status: 'normal', current: true },
|
||||
],
|
||||
}))),
|
||||
}
|
||||
@ -67,13 +59,12 @@ function fakeClient(opts: {
|
||||
|
||||
describe('runUseWorkspace', () => {
|
||||
let configDir: string
|
||||
let prevConfigDir: string | undefined
|
||||
|
||||
let prevConfigDir: string | undefined
|
||||
beforeEach(async () => {
|
||||
configDir = await mkdtemp(join(tmpdir(), 'difyctl-use-workspace-'))
|
||||
prevConfigDir = process.env[ENV_CONFIG_DIR]
|
||||
process.env[ENV_CONFIG_DIR] = configDir
|
||||
selectFromListMock.mockReset()
|
||||
})
|
||||
afterEach(async () => {
|
||||
if (prevConfigDir === undefined)
|
||||
@ -83,7 +74,7 @@ describe('runUseWorkspace', () => {
|
||||
await rm(configDir, { recursive: true, force: true })
|
||||
})
|
||||
|
||||
it('arg path: switches directly without listing and persists only the active workspace', async () => {
|
||||
it('happy path: POST /switch → GET /workspaces → write hosts.yml', async () => {
|
||||
const io = bufferStreams()
|
||||
const reg = makeRegistry()
|
||||
reg.save()
|
||||
@ -91,7 +82,7 @@ describe('runUseWorkspace', () => {
|
||||
const client = fakeClient({})
|
||||
|
||||
const next = await runUseWorkspace(
|
||||
{ workspaceId: '00000000-0000-0000-0000-000000000002' },
|
||||
{ workspaceId: 'ws-2' },
|
||||
{
|
||||
reg,
|
||||
active,
|
||||
@ -101,42 +92,62 @@ describe('runUseWorkspace', () => {
|
||||
},
|
||||
)
|
||||
|
||||
expect(client.switch).toHaveBeenCalledExactlyOnceWith('00000000-0000-0000-0000-000000000002')
|
||||
expect(client.list).not.toHaveBeenCalled()
|
||||
expect(client.switch).toHaveBeenCalledExactlyOnceWith('ws-2')
|
||||
expect(client.list).toHaveBeenCalledOnce()
|
||||
|
||||
const activeCtx = next.resolveActive()
|
||||
expect(activeCtx?.ctx.workspace).toEqual({ id: '00000000-0000-0000-0000-000000000002', name: 'Two', role: 'owner' })
|
||||
expect((activeCtx?.ctx as Record<string, unknown> | undefined)?.available_workspaces).toBeUndefined()
|
||||
expect(activeCtx?.ctx.workspace).toEqual({ id: 'ws-2', name: 'Switched', role: 'normal' })
|
||||
expect(activeCtx?.ctx.available_workspaces).toEqual([
|
||||
{ id: 'ws-1', name: 'Default', role: 'owner' },
|
||||
{ id: 'ws-2', name: 'Switched', role: 'normal' },
|
||||
])
|
||||
|
||||
const reloaded = Registry.load()
|
||||
const reloadedActive = reloaded?.resolveActive()
|
||||
expect(reloadedActive?.ctx.workspace?.id).toBe('00000000-0000-0000-0000-000000000002')
|
||||
expect(reloadedActive?.ctx.workspace?.name).toBe('Two')
|
||||
expect((reloadedActive?.ctx as Record<string, unknown> | undefined)?.available_workspaces).toBeUndefined()
|
||||
expect(reloadedActive?.ctx.workspace?.id).toBe('ws-2')
|
||||
expect(reloadedActive?.ctx.workspace?.name).toBe('Switched')
|
||||
|
||||
expect(io.outBuf()).toMatch(/Switched to Two \(00000000-0000-0000-0000-000000000002\)/)
|
||||
expect(io.outBuf()).toMatch(/Switched to Switched \(ws-2\)/)
|
||||
})
|
||||
|
||||
it('no-arg + no-TTY: rejects with usage_missing_arg and never switches', async () => {
|
||||
it('hosts.yml contains no bearer after switch', async () => {
|
||||
const io = bufferStreams()
|
||||
io.isErrTTY = false
|
||||
const reg = makeRegistry()
|
||||
reg.save()
|
||||
const active = makeActive(reg)
|
||||
const client = fakeClient({})
|
||||
|
||||
await expect(
|
||||
runUseWorkspace(
|
||||
{ workspaceId: undefined },
|
||||
{ reg, active, http: {} as HttpClient, io, workspacesFactory: () => client as never },
|
||||
),
|
||||
).rejects.toMatchObject({ code: 'usage_missing_arg' })
|
||||
await runUseWorkspace(
|
||||
{ workspaceId: 'ws-2' },
|
||||
{ reg, active, http: {} as HttpClient, io, workspacesFactory: () => client as never },
|
||||
)
|
||||
|
||||
expect(client.switch).not.toHaveBeenCalled()
|
||||
expect(client.list).not.toHaveBeenCalled()
|
||||
const reloaded = Registry.load()
|
||||
const raw = JSON.stringify(reloaded)
|
||||
expect(raw).not.toMatch(/bearer/)
|
||||
})
|
||||
|
||||
it('switch failure: rejects and leaves the active workspace untouched', async () => {
|
||||
it('refreshes stale workspace name from server', async () => {
|
||||
// registry has ws-2 named "Stale Name"; server returns "Switched".
|
||||
// We expect saveRegistry to record the fresh name from the server.
|
||||
const io = bufferStreams()
|
||||
const reg = makeRegistry()
|
||||
reg.save()
|
||||
const active = makeActive(reg)
|
||||
const client = fakeClient({})
|
||||
|
||||
await runUseWorkspace(
|
||||
{ workspaceId: 'ws-2' },
|
||||
{ reg, active, http: {} as HttpClient, io, workspacesFactory: () => client as never },
|
||||
)
|
||||
|
||||
const reloaded = Registry.load()
|
||||
const reloadedActive = reloaded?.resolveActive()
|
||||
expect(reloadedActive?.ctx.workspace?.name).toBe('Switched')
|
||||
expect(reloadedActive?.ctx.available_workspaces?.find(w => w.id === 'ws-2')?.name).toBe('Switched')
|
||||
})
|
||||
|
||||
it('does NOT mutate hosts.yml when POST /switch fails', async () => {
|
||||
const io = bufferStreams()
|
||||
const reg = makeRegistry()
|
||||
reg.save()
|
||||
@ -149,41 +160,85 @@ describe('runUseWorkspace', () => {
|
||||
|
||||
await expect(
|
||||
runUseWorkspace(
|
||||
{ workspaceId: '00000000-0000-0000-0000-000000000002' },
|
||||
{ reg, active, http: {} as HttpClient, io, workspacesFactory: () => client as never },
|
||||
{ workspaceId: 'ws-2' },
|
||||
{
|
||||
reg,
|
||||
active,
|
||||
http: {} as HttpClient,
|
||||
io,
|
||||
workspacesFactory: () => client as never,
|
||||
},
|
||||
),
|
||||
).rejects.toThrow(/forbidden/)
|
||||
|
||||
expect(client.list).not.toHaveBeenCalled()
|
||||
const after = Registry.load()
|
||||
expect(after).toEqual(before)
|
||||
expect(after?.resolveActive()?.ctx.workspace?.id).toBe('ws-1')
|
||||
const afterActive = after?.resolveActive()
|
||||
expect(afterActive?.ctx.workspace?.id).toBe('ws-1')
|
||||
})
|
||||
|
||||
it('picker path (TTY): lists live workspaces and switches to the selected one', async () => {
|
||||
it('does NOT mutate hosts.yml when GET /workspaces fails after switch', async () => {
|
||||
const io = bufferStreams()
|
||||
io.isErrTTY = true
|
||||
const reg = makeRegistry()
|
||||
reg.save()
|
||||
const active = makeActive(reg)
|
||||
const client = fakeClient({})
|
||||
const before = Registry.load()
|
||||
|
||||
selectFromListMock.mockResolvedValue({ id: '00000000-0000-0000-0000-000000000002', name: 'Two', role: 'owner' })
|
||||
const client = fakeClient({
|
||||
list: () => Promise.reject(new Error('transient list failure')),
|
||||
})
|
||||
|
||||
await runUseWorkspace(
|
||||
{ workspaceId: undefined },
|
||||
{ reg, active, http: {} as HttpClient, io, workspacesFactory: () => client as never },
|
||||
)
|
||||
await expect(
|
||||
runUseWorkspace(
|
||||
{ workspaceId: 'ws-2' },
|
||||
{
|
||||
reg,
|
||||
active,
|
||||
http: {} as HttpClient,
|
||||
io,
|
||||
workspacesFactory: () => client as never,
|
||||
},
|
||||
),
|
||||
).rejects.toThrow(/transient list failure/)
|
||||
|
||||
expect(client.list).toHaveBeenCalledOnce()
|
||||
expect(selectFromListMock).toHaveBeenCalledOnce()
|
||||
const passed = selectFromListMock.mock.calls[0]![0]
|
||||
expect(passed.items).toEqual([
|
||||
{ id: 'ws-1', name: 'Default', role: 'owner' },
|
||||
{ id: '00000000-0000-0000-0000-000000000002', name: 'Two', role: 'owner' },
|
||||
])
|
||||
expect(client.switch).toHaveBeenCalledExactlyOnceWith('00000000-0000-0000-0000-000000000002')
|
||||
const after = Registry.load()
|
||||
expect(after).toEqual(before)
|
||||
})
|
||||
|
||||
const reloadedActive = Registry.load()?.resolveActive()
|
||||
expect(reloadedActive?.ctx.workspace?.id).toBe('00000000-0000-0000-0000-000000000002')
|
||||
it('throws when server returns switch=<id> but id is missing from /workspaces list', async () => {
|
||||
const io = bufferStreams()
|
||||
const reg = makeRegistry()
|
||||
reg.save()
|
||||
const active = makeActive(reg)
|
||||
|
||||
const client = fakeClient({
|
||||
switch: () => Promise.resolve({
|
||||
id: 'ws-7',
|
||||
name: 'Ghost',
|
||||
role: 'normal',
|
||||
status: 'normal',
|
||||
current: true,
|
||||
created_at: null as unknown as string,
|
||||
}),
|
||||
list: () => Promise.resolve({
|
||||
workspaces: [
|
||||
{ id: 'ws-1', name: 'Default', role: 'owner', status: 'normal', current: false },
|
||||
],
|
||||
}),
|
||||
})
|
||||
|
||||
await expect(
|
||||
runUseWorkspace(
|
||||
{ workspaceId: 'ws-7' },
|
||||
{
|
||||
reg,
|
||||
active,
|
||||
http: {} as HttpClient,
|
||||
io,
|
||||
workspacesFactory: () => client as never,
|
||||
},
|
||||
),
|
||||
).rejects.toThrow(/not visible in \/workspaces/)
|
||||
})
|
||||
})
|
||||
|
||||
@ -5,11 +5,10 @@ import { WorkspacesClient } from '@/api/workspaces'
|
||||
import { BaseError } from '@/errors/base'
|
||||
import { ErrorCode } from '@/errors/codes'
|
||||
import { colorEnabled, colorScheme } from '@/sys/io/color'
|
||||
import { selectFromList } from '@/sys/io/select'
|
||||
import { runWithSpinner } from '@/sys/io/spinner'
|
||||
|
||||
export type UseWorkspaceOptions = {
|
||||
readonly workspaceId?: string
|
||||
readonly workspaceId: string
|
||||
}
|
||||
|
||||
export type UseWorkspaceDeps = {
|
||||
@ -23,12 +22,16 @@ export type UseWorkspaceDeps = {
|
||||
/**
|
||||
* Switch the caller's active workspace.
|
||||
*
|
||||
* With an explicit id we switch directly; with no id we fetch the live
|
||||
* workspace list and let the caller pick one interactively (TTY only).
|
||||
*
|
||||
* The server-side switch is the source of truth: if POST
|
||||
* `/workspaces/<id>/switch` fails we abort before touching `hosts.yml`, so
|
||||
* local state never diverges from the server.
|
||||
* Strict ordering:
|
||||
* 1. POST /workspaces/<id>/switch — if this fails (403/404/etc.) we abort
|
||||
* with no `hosts.yml` mutation, so local state never diverges from the
|
||||
* server. Any fallback to a pure-local update is explicitly disallowed
|
||||
* (see workspace-plan.md decision D4).
|
||||
* 2. GET /workspaces — refresh the membership list so `available_workspaces`
|
||||
* stays in sync. Failure here also aborts; the server-side current has
|
||||
* already moved, but the local file is left untouched. A follow-up
|
||||
* `difyctl get workspace` will reconcile.
|
||||
* 3. Persist `workspace` + `available_workspaces` atomically via `saveRegistry`.
|
||||
*/
|
||||
export async function runUseWorkspace(
|
||||
opts: UseWorkspaceOptions,
|
||||
@ -38,51 +41,32 @@ export async function runUseWorkspace(
|
||||
const factory = deps.workspacesFactory ?? ((h: HttpClient) => new WorkspacesClient(h))
|
||||
const client = factory(deps.http)
|
||||
|
||||
const argId = opts.workspaceId?.trim() ?? ''
|
||||
const id = argId !== '' ? argId : await pickWorkspaceId(client, deps)
|
||||
|
||||
const detail = await runWithSpinner(
|
||||
{ io: deps.io, label: `Switching to ${id}` },
|
||||
() => client.switch(id),
|
||||
{ io: deps.io, label: `Switching to ${opts.workspaceId}` },
|
||||
() => client.switch(opts.workspaceId),
|
||||
)
|
||||
|
||||
const list = await runWithSpinner(
|
||||
{ io: deps.io, label: 'Refreshing workspaces' },
|
||||
() => client.list(),
|
||||
)
|
||||
|
||||
const matched = list.workspaces.find(w => w.id === detail.id)
|
||||
if (matched === undefined) {
|
||||
throw new BaseError({
|
||||
code: ErrorCode.Unknown,
|
||||
message: `server returned switch=${detail.id} but it is not visible in /workspaces`,
|
||||
hint: 'try again or contact your workspace admin',
|
||||
})
|
||||
}
|
||||
|
||||
const nextCtx = {
|
||||
...deps.active.ctx,
|
||||
workspace: { id: detail.id, name: detail.name, role: detail.role },
|
||||
workspace: { id: matched.id, name: matched.name, role: matched.role },
|
||||
available_workspaces: list.workspaces.map<Workspace>(w => ({ id: w.id, name: w.name, role: w.role })),
|
||||
}
|
||||
deps.reg.upsert(deps.active.host, deps.active.email, nextCtx)
|
||||
deps.reg.save()
|
||||
deps.io.out.write(`${cs.successIcon()} Switched to ${detail.name} (${detail.id})\n`)
|
||||
deps.io.out.write(`${cs.successIcon()} Switched to ${matched.name} (${matched.id})\n`)
|
||||
return deps.reg
|
||||
}
|
||||
|
||||
async function pickWorkspaceId(client: WorkspacesClient, deps: UseWorkspaceDeps): Promise<string> {
|
||||
if (!deps.io.isErrTTY) {
|
||||
throw new BaseError({
|
||||
code: ErrorCode.UsageMissingArg,
|
||||
message: 'a workspace id is required (no TTY)',
|
||||
hint: 'pass the id: \'difyctl use workspace <id>\'',
|
||||
})
|
||||
}
|
||||
|
||||
const list = await runWithSpinner(
|
||||
{ io: deps.io, label: 'Loading workspaces' },
|
||||
() => client.list(),
|
||||
)
|
||||
const items = list.workspaces.map<Workspace>(w => ({ id: w.id, name: w.name, role: w.role }))
|
||||
if (items.length === 0) {
|
||||
throw new BaseError({
|
||||
code: ErrorCode.AccessDenied,
|
||||
message: 'no workspaces available to switch to',
|
||||
})
|
||||
}
|
||||
|
||||
const activeId = deps.active.ctx.workspace?.id
|
||||
const picked = await selectFromList<Workspace>({
|
||||
io: deps.io,
|
||||
items,
|
||||
header: 'Select a workspace',
|
||||
render: w => `${w.id === activeId ? '* ' : ' '}${w.name} (${w.role})`,
|
||||
})
|
||||
return picked.id
|
||||
}
|
||||
|
||||
@ -17,7 +17,6 @@ export const ErrorCode = {
|
||||
ClientError: 'client_error',
|
||||
Unknown: 'unknown',
|
||||
IllegalArgumentError: 'illegal_argument',
|
||||
KeyringUnavailable: 'keyring_unavailable',
|
||||
} as const
|
||||
|
||||
export type ErrorCodeValue = (typeof ErrorCode)[keyof typeof ErrorCode]
|
||||
@ -51,7 +50,6 @@ const CODE_TO_EXIT: Readonly<Record<ErrorCodeValue, ExitCodeValue>> = {
|
||||
client_error: ExitCode.Generic,
|
||||
unknown: ExitCode.Generic,
|
||||
illegal_argument: ExitCode.Usage,
|
||||
keyring_unavailable: ExitCode.Generic,
|
||||
}
|
||||
|
||||
export function exitFor(code: string): ExitCodeValue {
|
||||
|
||||
@ -1,138 +0,0 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { BaseError } from '@/errors/base'
|
||||
import { ErrorCode } from '@/errors/codes'
|
||||
|
||||
type EntryArgs = { service: string, username: string }
|
||||
|
||||
const passwords = new Map<string, string>()
|
||||
const constructed: EntryArgs[] = []
|
||||
let getPasswordError: Error | null = null
|
||||
let setPasswordError: Error | null = null
|
||||
|
||||
class FakeEntry {
|
||||
private readonly key: string
|
||||
constructor(service: string, username: string) {
|
||||
constructed.push({ service, username })
|
||||
this.key = `${service}::${username}`
|
||||
}
|
||||
|
||||
setPassword(value: string): void {
|
||||
if (setPasswordError !== null)
|
||||
throw setPasswordError
|
||||
passwords.set(this.key, value)
|
||||
}
|
||||
|
||||
getPassword(): string | null {
|
||||
if (getPasswordError !== null)
|
||||
throw getPasswordError
|
||||
return passwords.get(this.key) ?? null
|
||||
}
|
||||
|
||||
deletePassword(): boolean {
|
||||
if (!passwords.has(this.key))
|
||||
return false
|
||||
passwords.delete(this.key)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
vi.mock('@napi-rs/keyring', () => ({
|
||||
Entry: FakeEntry,
|
||||
}))
|
||||
|
||||
const { KeychainTokenStore } = await import('./token-store')
|
||||
|
||||
const SERVICE = 'difyctl-test'
|
||||
|
||||
beforeEach(() => {
|
||||
passwords.clear()
|
||||
constructed.length = 0
|
||||
getPasswordError = null
|
||||
setPasswordError = null
|
||||
})
|
||||
|
||||
describe('KeychainTokenStore', () => {
|
||||
it('round-trips a bearer through write/read', () => {
|
||||
const store = new KeychainTokenStore(SERVICE)
|
||||
store.write('https://cloud.dify.ai', 'a@x.com', 'dfoa_secret')
|
||||
expect(store.read('https://cloud.dify.ai', 'a@x.com')).toBe('dfoa_secret')
|
||||
})
|
||||
|
||||
it('returns empty string for an absent credential', () => {
|
||||
const store = new KeychainTokenStore(SERVICE)
|
||||
expect(store.read('https://cloud.dify.ai', 'missing@x.com')).toBe('')
|
||||
})
|
||||
|
||||
it('removes a credential, after which read returns empty string', () => {
|
||||
const store = new KeychainTokenStore(SERVICE)
|
||||
store.write('h', 'e', 'dfoa_secret')
|
||||
store.remove('h', 'e')
|
||||
expect(store.read('h', 'e')).toBe('')
|
||||
})
|
||||
|
||||
it('treats remove of an absent credential as a no-op', () => {
|
||||
const store = new KeychainTokenStore(SERVICE)
|
||||
expect(() => store.remove('h', 'absent')).not.toThrow()
|
||||
})
|
||||
|
||||
it('uses the legacy entry name tokens.<host>.<email> (back-compat)', () => {
|
||||
const store = new KeychainTokenStore(SERVICE)
|
||||
store.write('https://cloud.dify.ai', 'a@x.com', 'dfoa_secret')
|
||||
expect(constructed).toContainEqual({
|
||||
service: SERVICE,
|
||||
username: 'tokens.https://cloud.dify.ai.a@x.com',
|
||||
})
|
||||
})
|
||||
|
||||
it('keeps host and email literal — dots, colons, and @ are never split', () => {
|
||||
const store = new KeychainTokenStore(SERVICE)
|
||||
const host = 'https://my.dify.example.com:8443'
|
||||
const email = 'first.last@sub.example.com'
|
||||
store.write(host, email, 'dfoa_literal')
|
||||
expect(store.read(host, email)).toBe('dfoa_literal')
|
||||
expect(constructed).toContainEqual({
|
||||
service: SERVICE,
|
||||
username: `tokens.${host}.${email}`,
|
||||
})
|
||||
})
|
||||
|
||||
it('returns empty string when the stored value decodes to a non-string', () => {
|
||||
const store = new KeychainTokenStore(SERVICE)
|
||||
passwords.set(`${SERVICE}::tokens.h.e`, '123')
|
||||
expect(store.read('h', 'e')).toBe('')
|
||||
})
|
||||
|
||||
it('returns empty string when the stored value is not valid JSON', () => {
|
||||
const store = new KeychainTokenStore(SERVICE)
|
||||
passwords.set(`${SERVICE}::tokens.h.e`, 'not-json')
|
||||
expect(store.read('h', 'e')).toBe('')
|
||||
})
|
||||
|
||||
it('throws KeyringUnavailable (not empty string) when keyring access fails on read', () => {
|
||||
getPasswordError = new Error('keyring locked')
|
||||
const store = new KeychainTokenStore(SERVICE)
|
||||
let caught: unknown
|
||||
try {
|
||||
store.read('h', 'e')
|
||||
}
|
||||
catch (err) {
|
||||
caught = err
|
||||
}
|
||||
expect(caught).toBeInstanceOf(BaseError)
|
||||
expect((caught as BaseError).code).toBe(ErrorCode.KeyringUnavailable)
|
||||
})
|
||||
|
||||
it('throws KeyringUnavailable when keyring access fails on write', () => {
|
||||
setPasswordError = new Error('keyring locked')
|
||||
const store = new KeychainTokenStore(SERVICE)
|
||||
let caught: unknown
|
||||
try {
|
||||
store.write('h', 'e', 'dfoa_secret')
|
||||
}
|
||||
catch (err) {
|
||||
caught = err
|
||||
}
|
||||
expect(caught).toBeInstanceOf(BaseError)
|
||||
expect((caught as BaseError).code).toBe(ErrorCode.KeyringUnavailable)
|
||||
})
|
||||
})
|
||||
@ -1,29 +1,28 @@
|
||||
import type { TokenStore } from './token-store'
|
||||
import type { Key, Store } from './store'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { detectTokenStore, getTokenStore } from './manager'
|
||||
import { getTokenStore } from './manager'
|
||||
|
||||
function memStore(label: string): TokenStore & { _label: string } {
|
||||
const map = new Map<string, string>()
|
||||
const k = (h: string, e: string): string => `${h} ${e}`
|
||||
function memStore(label: string): Store & { _label: string } {
|
||||
const map = new Map<string, unknown>()
|
||||
return {
|
||||
_label: label,
|
||||
read(host: string, email: string): string {
|
||||
return map.get(k(host, email)) ?? ''
|
||||
get<T>(key: Key<T>): T {
|
||||
return (map.get(key.key) as T | undefined) ?? key.default
|
||||
},
|
||||
write(host: string, email: string, bearer: string): void {
|
||||
map.set(k(host, email), bearer)
|
||||
set<T>(key: Key<T>, value: T): void {
|
||||
map.set(key.key, value)
|
||||
},
|
||||
remove(host: string, email: string): void {
|
||||
map.delete(k(host, email))
|
||||
unset<T>(key: Key<T>): void {
|
||||
map.delete(key.key)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
describe('detectTokenStore', () => {
|
||||
describe('getTokenStore', () => {
|
||||
it('returns keychain store when probe succeeds', () => {
|
||||
const k = memStore('keyring')
|
||||
const f = memStore('file')
|
||||
const result = detectTokenStore({
|
||||
const result = getTokenStore({
|
||||
factory: { keyring: () => k, file: () => f },
|
||||
})
|
||||
expect(result.mode).toBe('keychain')
|
||||
@ -33,10 +32,12 @@ describe('detectTokenStore', () => {
|
||||
it('falls back to file when keyring set throws', () => {
|
||||
const k = memStore('keyring')
|
||||
const f = memStore('file')
|
||||
k.write = vi.fn(() => {
|
||||
throw new Error('locked')
|
||||
})
|
||||
const result = detectTokenStore({
|
||||
k.set = vi.fn(
|
||||
() => {
|
||||
throw new Error('locked')
|
||||
},
|
||||
)
|
||||
const result = getTokenStore({
|
||||
factory: { keyring: () => k, file: () => f },
|
||||
})
|
||||
expect(result.mode).toBe('file')
|
||||
@ -46,8 +47,8 @@ describe('detectTokenStore', () => {
|
||||
it('falls back to file when probe round-trip mismatches', () => {
|
||||
const k = memStore('keyring')
|
||||
const f = memStore('file')
|
||||
k.read = vi.fn(() => 'something-else') as TokenStore['read']
|
||||
const result = detectTokenStore({
|
||||
k.get = vi.fn(() => 'something-else') as Store['get']
|
||||
const result = getTokenStore({
|
||||
factory: { keyring: () => k, file: () => f },
|
||||
})
|
||||
expect(result.mode).toBe('file')
|
||||
@ -56,7 +57,7 @@ describe('detectTokenStore', () => {
|
||||
|
||||
it('falls back to file when keyring constructor throws', () => {
|
||||
const f = memStore('file')
|
||||
const result = detectTokenStore({
|
||||
const result = getTokenStore({
|
||||
factory: {
|
||||
keyring: () => { throw new Error('no backend') },
|
||||
file: () => f,
|
||||
@ -69,48 +70,9 @@ describe('detectTokenStore', () => {
|
||||
it('cleans up probe entry after successful probe', () => {
|
||||
const k = memStore('keyring')
|
||||
const f = memStore('file')
|
||||
detectTokenStore({
|
||||
getTokenStore({
|
||||
factory: { keyring: () => k, file: () => f },
|
||||
})
|
||||
expect(k.read('__difyctl_probe__', '__difyctl_probe__')).toBe('')
|
||||
})
|
||||
|
||||
it('removes the probe entry even when the probe read throws', () => {
|
||||
const k = memStore('keyring')
|
||||
const f = memStore('file')
|
||||
const removeSpy = vi.spyOn(k, 'remove')
|
||||
k.read = vi.fn(() => {
|
||||
throw new Error('read boom')
|
||||
}) as TokenStore['read']
|
||||
const result = detectTokenStore({
|
||||
factory: { keyring: () => k, file: () => f },
|
||||
})
|
||||
expect(removeSpy).toHaveBeenCalledWith('__difyctl_probe__', '__difyctl_probe__')
|
||||
expect(result.mode).toBe('file')
|
||||
expect(result.store).toBe(f)
|
||||
})
|
||||
})
|
||||
|
||||
describe('getTokenStore', () => {
|
||||
it('constructs the keychain backend without probing when mode is keychain', () => {
|
||||
const k = memStore('keyring')
|
||||
const f = memStore('file')
|
||||
k.write = vi.fn(() => {
|
||||
throw new Error('probe must never run on the read path')
|
||||
})
|
||||
const store = getTokenStore('keychain', {
|
||||
factory: { keyring: () => k, file: () => f },
|
||||
})
|
||||
expect(store).toBe(k)
|
||||
})
|
||||
|
||||
it('constructs the file backend when mode is file, never touching the keyring', () => {
|
||||
const keyringFactory = vi.fn(() => memStore('keyring'))
|
||||
const f = memStore('file')
|
||||
const store = getTokenStore('file', {
|
||||
factory: { keyring: keyringFactory, file: () => f },
|
||||
})
|
||||
expect(store).toBe(f)
|
||||
expect(keyringFactory).not.toHaveBeenCalled()
|
||||
expect(k.get({ key: '__difyctl_probe__', default: '' })).toBe('')
|
||||
})
|
||||
})
|
||||
|
||||
@ -1,9 +1,7 @@
|
||||
import type { StorageMode, Store } from './store'
|
||||
import type { TokenStore } from './token-store'
|
||||
import type { Key, StorageMode, Store } from './store'
|
||||
import { join } from 'node:path'
|
||||
import { resolveCacheDir, resolveConfigDir } from './dir'
|
||||
import { YamlStore } from './store'
|
||||
import { FileTokenStore, KeychainTokenStore } from './token-store'
|
||||
import { KeyringBasedStore, YamlStore } from './store'
|
||||
|
||||
export const CACHE_APP_INFO = 'app-info'
|
||||
export const CACHE_NUDGE = 'nudge'
|
||||
@ -33,52 +31,51 @@ export function getHostStore(): YamlStore {
|
||||
return getStore(join(resolveConfigDir(), HOSTS_FILE))
|
||||
}
|
||||
|
||||
const PROBE_HOST = '__difyctl_probe__'
|
||||
const PROBE_EMAIL = '__difyctl_probe__'
|
||||
const PROBE_KEY: Key<string> = { key: '__difyctl_probe__', default: '' }
|
||||
const PROBE_VALUE = 'probe-v1'
|
||||
|
||||
export type GetTokenStoreOptions = {
|
||||
readonly factory?: {
|
||||
readonly keyring?: () => TokenStore
|
||||
readonly file?: () => TokenStore
|
||||
readonly keyring?: () => Store
|
||||
readonly file?: () => Store
|
||||
}
|
||||
}
|
||||
|
||||
const TOKEN_STORE_OPENERS: Record<StorageMode, (opts: GetTokenStoreOptions) => TokenStore> = {
|
||||
file: opts => opts.factory?.file?.() ?? new FileTokenStore(join(resolveConfigDir(), TOKENS_FILE)),
|
||||
keychain: opts => opts.factory?.keyring?.() ?? new KeychainTokenStore(KEYRING_SERVICE),
|
||||
}
|
||||
|
||||
/**
|
||||
* Decide which credential backend to use by probing the OS keyring with a
|
||||
* write/read/remove round-trip. The probe MUTATES the keyring, so call this
|
||||
* only where a credential is about to be written anyway (login).
|
||||
* Single entry point for the credential store. Probes the OS keyring; if it
|
||||
* round-trips a value, returns the keychain-backed store. Otherwise falls
|
||||
* back to the YAML file at `<configDir>/tokens.yml`. Both implementations
|
||||
* satisfy the `Store` interface, so callers interact uniformly.
|
||||
*
|
||||
* Business logic should always obtain the token store through this factory
|
||||
* rather than constructing one directly.
|
||||
*/
|
||||
export function detectTokenStore(opts: GetTokenStoreOptions = {}): { store: TokenStore, mode: StorageMode } {
|
||||
export function getTokenStore(opts: GetTokenStoreOptions = {}): { store: Store, mode: StorageMode } {
|
||||
const fileFactory = opts.factory?.file ?? (() => getStore(join(resolveConfigDir(), TOKENS_FILE)))
|
||||
const keyringFactory = opts.factory?.keyring ?? (() => new KeyringBasedStore(KEYRING_SERVICE))
|
||||
// DIFY_E2E_NO_KEYRING=1 forces file-based storage in E2E tests to avoid
|
||||
// macOS keychain UI prompts blocking child processes spawned by vitest.
|
||||
if (process.env.DIFY_E2E_NO_KEYRING === '1')
|
||||
return { store: TOKEN_STORE_OPENERS.file(opts), mode: 'file' }
|
||||
return { store: fileFactory(), mode: 'file' }
|
||||
try {
|
||||
const k = TOKEN_STORE_OPENERS.keychain(opts)
|
||||
k.write(PROBE_HOST, PROBE_EMAIL, PROBE_VALUE)
|
||||
let got = ''
|
||||
try {
|
||||
got = k.read(PROBE_HOST, PROBE_EMAIL)
|
||||
}
|
||||
finally {
|
||||
k.remove(PROBE_HOST, PROBE_EMAIL)
|
||||
}
|
||||
if (got === PROBE_VALUE)
|
||||
return { store: k, mode: 'keychain' }
|
||||
const k = keyringFactory()
|
||||
k.set(PROBE_KEY, PROBE_VALUE)
|
||||
const got = k.get(PROBE_KEY)
|
||||
k.unset(PROBE_KEY)
|
||||
if (got !== PROBE_VALUE)
|
||||
throw new Error('keyring round-trip mismatch')
|
||||
return { store: k, mode: 'keychain' }
|
||||
}
|
||||
catch {
|
||||
return { store: fileFactory(), mode: 'file' }
|
||||
}
|
||||
catch { /* keyring unavailable → fall through to file */ }
|
||||
return { store: TOKEN_STORE_OPENERS.file(opts), mode: 'file' }
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct the credential backend the registry already recorded at login.
|
||||
* Maps an auth identity (host + accountId) to a `Store` key. All token store
|
||||
* reads/writes in business logic go through this helper so the on-disk /
|
||||
* keyring layout stays consistent.
|
||||
*/
|
||||
export function getTokenStore(mode: StorageMode, opts: GetTokenStoreOptions = {}): TokenStore {
|
||||
return TOKEN_STORE_OPENERS[mode](opts)
|
||||
export function tokenKey(host: string, accountId: string): Key<string> {
|
||||
return { key: `tokens.${host}.${accountId}`, default: '' }
|
||||
}
|
||||
|
||||
@ -21,8 +21,7 @@ export type Store = {
|
||||
unset: <T>(key: Key<T>) => void
|
||||
}
|
||||
|
||||
export const STORAGE_MODES = ['keychain', 'file'] as const
|
||||
export type StorageMode = typeof STORAGE_MODES[number]
|
||||
export type StorageMode = 'keychain' | 'file'
|
||||
|
||||
abstract class FileBasedStore implements Store {
|
||||
filePath: string
|
||||
|
||||
@ -1,81 +0,0 @@
|
||||
import { mkdtempSync, readFileSync, rmSync, writeFileSync } from 'node:fs'
|
||||
import { tmpdir } from 'node:os'
|
||||
import { join } from 'node:path'
|
||||
import { afterEach, beforeEach, describe, expect, it } from 'vitest'
|
||||
import { FileTokenStore } from './token-store'
|
||||
|
||||
describe('FileTokenStore', () => {
|
||||
let dir: string
|
||||
let file: string
|
||||
|
||||
beforeEach(() => {
|
||||
dir = mkdtempSync(join(tmpdir(), 'difyctl-tok-'))
|
||||
file = join(dir, 'tokens.yml')
|
||||
})
|
||||
afterEach(() => rmSync(dir, { recursive: true, force: true }))
|
||||
|
||||
it('returns empty string for a missing credential', () => {
|
||||
const s = new FileTokenStore(file)
|
||||
expect(s.read('https://cloud.dify.ai', 'a@x.com')).toBe('')
|
||||
})
|
||||
|
||||
it('round-trips a bearer with dots and @ kept literal', () => {
|
||||
const s = new FileTokenStore(file)
|
||||
s.write('https://cloud.dify.ai', 'a.b@x.com', 'dfoa_secret')
|
||||
expect(s.read('https://cloud.dify.ai', 'a.b@x.com')).toBe('dfoa_secret')
|
||||
})
|
||||
|
||||
it('keeps multiple accounts under one host and isolates hosts', () => {
|
||||
const s = new FileTokenStore(file)
|
||||
s.write('https://cloud.dify.ai', 'a@x.com', 'A')
|
||||
s.write('https://cloud.dify.ai', 'b@x.com', 'B')
|
||||
s.write('https://self.example.com', 'a@x.com', 'C')
|
||||
expect(s.read('https://cloud.dify.ai', 'a@x.com')).toBe('A')
|
||||
expect(s.read('https://cloud.dify.ai', 'b@x.com')).toBe('B')
|
||||
expect(s.read('https://self.example.com', 'a@x.com')).toBe('C')
|
||||
})
|
||||
|
||||
it('persists the versioned nested shape on disk', () => {
|
||||
const s = new FileTokenStore(file)
|
||||
s.write('https://cloud.dify.ai', 'a@x.com', 'A')
|
||||
const raw = readFileSync(file, 'utf8')
|
||||
expect(raw).toContain('version: 1')
|
||||
expect(raw).toContain('https://cloud.dify.ai')
|
||||
expect(raw).toContain('a@x.com')
|
||||
})
|
||||
|
||||
it('reads empty when the document version is an unknown future version', () => {
|
||||
writeFileSync(file, 'version: 999\ntokens:\n "h":\n "e": "x"\n')
|
||||
const s = new FileTokenStore(file)
|
||||
expect(s.read('h', 'e')).toBe('')
|
||||
})
|
||||
|
||||
it('reads tokens from legacy format (no version field) for transparent migration', () => {
|
||||
writeFileSync(file, 'tokens:\n "h":\n "e": "dfoa_legacy"\n')
|
||||
const s = new FileTokenStore(file)
|
||||
expect(s.read('h', 'e')).toBe('dfoa_legacy')
|
||||
})
|
||||
|
||||
it('preserves existing tokens and stamps version when writing to a legacy file', () => {
|
||||
writeFileSync(file, 'tokens:\n "h":\n "existing@x": "dfoa_existing"\n')
|
||||
const s = new FileTokenStore(file)
|
||||
s.write('h', 'new@x', 'dfoa_new')
|
||||
expect(s.read('h', 'existing@x')).toBe('dfoa_existing')
|
||||
expect(s.read('h', 'new@x')).toBe('dfoa_new')
|
||||
expect(readFileSync(file, 'utf8')).toContain('version: 1')
|
||||
})
|
||||
|
||||
it('remove deletes the credential and prunes the empty host map', () => {
|
||||
const s = new FileTokenStore(file)
|
||||
s.write('https://cloud.dify.ai', 'a@x.com', 'A')
|
||||
s.remove('https://cloud.dify.ai', 'a@x.com')
|
||||
expect(s.read('https://cloud.dify.ai', 'a@x.com')).toBe('')
|
||||
const raw = readFileSync(file, 'utf8')
|
||||
expect(raw).not.toContain('cloud.dify.ai')
|
||||
})
|
||||
|
||||
it('remove is a no-op for an absent credential', () => {
|
||||
const s = new FileTokenStore(file)
|
||||
expect(() => s.remove('h', 'e')).not.toThrow()
|
||||
})
|
||||
})
|
||||
@ -1,130 +0,0 @@
|
||||
import { Entry } from '@napi-rs/keyring'
|
||||
import { BaseError } from '@/errors/base'
|
||||
import { ErrorCode } from '@/errors/codes'
|
||||
import { YamlStore } from './store'
|
||||
|
||||
/**
|
||||
* Credential store keyed by an opaque (host, email) pair.
|
||||
*/
|
||||
export type TokenStore = {
|
||||
read: (host: string, email: string) => string
|
||||
write: (host: string, email: string, bearer: string) => void
|
||||
remove: (host: string, email: string) => void
|
||||
}
|
||||
|
||||
const DOC_VERSION = 1
|
||||
|
||||
type TokenDoc = {
|
||||
version?: number
|
||||
tokens?: Record<string, Record<string, string>>
|
||||
}
|
||||
|
||||
export class FileTokenStore implements TokenStore {
|
||||
private readonly store: YamlStore
|
||||
|
||||
constructor(filePath: string) {
|
||||
this.store = new YamlStore(filePath)
|
||||
}
|
||||
|
||||
read(host: string, email: string): string {
|
||||
const doc = this.store.getTyped<TokenDoc>()
|
||||
if (doc === null)
|
||||
return ''
|
||||
// missing version = legacy pre-v1 format (same data shape); future unknown versions are rejected
|
||||
if (doc.version !== undefined && doc.version !== DOC_VERSION)
|
||||
return ''
|
||||
return doc.tokens?.[host]?.[email] ?? ''
|
||||
}
|
||||
|
||||
write(host: string, email: string, bearer: string): void {
|
||||
const doc = this.load()
|
||||
const hostMap = doc.tokens[host] ?? {}
|
||||
hostMap[email] = bearer
|
||||
doc.tokens[host] = hostMap
|
||||
this.store.setTyped(doc)
|
||||
}
|
||||
|
||||
remove(host: string, email: string): void {
|
||||
const doc = this.store.getTyped<TokenDoc>()
|
||||
if (doc === null)
|
||||
return
|
||||
if (doc.version !== undefined && doc.version !== DOC_VERSION)
|
||||
return
|
||||
const tokens = doc.tokens ?? {}
|
||||
const hostMap = tokens[host]
|
||||
if (hostMap === undefined || !(email in hostMap))
|
||||
return
|
||||
delete hostMap[email]
|
||||
if (Object.keys(hostMap).length === 0)
|
||||
delete tokens[host]
|
||||
this.store.setTyped({ version: DOC_VERSION, tokens })
|
||||
}
|
||||
|
||||
private load(): { version: number, tokens: Record<string, Record<string, string>> } {
|
||||
const doc = this.store.getTyped<TokenDoc>()
|
||||
if (doc === null)
|
||||
return { version: DOC_VERSION, tokens: {} }
|
||||
if (doc.version !== undefined && doc.version !== DOC_VERSION)
|
||||
return { version: DOC_VERSION, tokens: {} }
|
||||
return { version: DOC_VERSION, tokens: (doc.tokens ?? {}) as Record<string, Record<string, string>> }
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* One OS-keyring entry per (host, email).
|
||||
*/
|
||||
export class KeychainTokenStore implements TokenStore {
|
||||
private readonly service: string
|
||||
|
||||
constructor(service: string) {
|
||||
this.service = service
|
||||
}
|
||||
|
||||
read(host: string, email: string): string {
|
||||
let raw: string | null
|
||||
try {
|
||||
raw = new Entry(this.service, entryName(host, email)).getPassword()
|
||||
}
|
||||
catch (err) {
|
||||
throw keyringUnavailableError(err)
|
||||
}
|
||||
if (raw === null || raw === undefined || raw === '')
|
||||
return ''
|
||||
try {
|
||||
const parsed: unknown = JSON.parse(raw)
|
||||
return typeof parsed === 'string' ? parsed : ''
|
||||
}
|
||||
catch {
|
||||
return ''
|
||||
}
|
||||
}
|
||||
|
||||
write(host: string, email: string, bearer: string): void {
|
||||
try {
|
||||
new Entry(this.service, entryName(host, email)).setPassword(JSON.stringify(bearer))
|
||||
}
|
||||
catch (err) {
|
||||
throw keyringUnavailableError(err)
|
||||
}
|
||||
}
|
||||
|
||||
remove(host: string, email: string): void {
|
||||
try {
|
||||
new Entry(this.service, entryName(host, email)).deletePassword()
|
||||
}
|
||||
catch { /* missing entry is fine */ }
|
||||
}
|
||||
}
|
||||
|
||||
function entryName(host: string, email: string): string {
|
||||
return `tokens.${host}.${email}`
|
||||
}
|
||||
|
||||
function keyringUnavailableError(cause: unknown): BaseError {
|
||||
return new BaseError({
|
||||
code: ErrorCode.KeyringUnavailable,
|
||||
message: 'OS keychain is unreachable',
|
||||
hint: 'credentials are stored in the system keychain but it could not be accessed; unlock the keychain (or the login session) and retry',
|
||||
cause,
|
||||
})
|
||||
}
|
||||
@ -1,29 +0,0 @@
|
||||
import type { ActiveContext } from '@/auth/hosts'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import { resolveWorkspaceId } from './resolver'
|
||||
|
||||
function active(workspaceId?: string): ActiveContext {
|
||||
return { host: 'h', email: 'e', ctx: { account: { id: '', email: 'e', name: '' }, workspace: workspaceId ? { id: workspaceId, name: 'W', role: 'owner' } : undefined } }
|
||||
}
|
||||
|
||||
const UUID_FLAG = 'aaaaaaaa-0000-0000-0000-000000000001'
|
||||
const UUID_ENV = 'aaaaaaaa-0000-0000-0000-000000000002'
|
||||
const UUID_CTX = 'aaaaaaaa-0000-0000-0000-000000000003'
|
||||
|
||||
describe('resolveWorkspaceId', () => {
|
||||
it('prefers the flag', () => {
|
||||
expect(resolveWorkspaceId({ flag: UUID_FLAG, env: UUID_ENV, active: active(UUID_CTX) })).toBe(UUID_FLAG)
|
||||
})
|
||||
it('falls back to env over active workspace', () => {
|
||||
expect(resolveWorkspaceId({ env: UUID_ENV, active: active(UUID_CTX) })).toBe(UUID_ENV)
|
||||
})
|
||||
it('falls back to active workspace when no flag or env', () => {
|
||||
expect(resolveWorkspaceId({ active: active(UUID_CTX) })).toBe(UUID_CTX)
|
||||
})
|
||||
it('throws when active workspace ID is not a valid UUID', () => {
|
||||
expect(() => resolveWorkspaceId({ active: active('ws-ctx') })).toThrow(/stored workspace ID/)
|
||||
})
|
||||
it('throws when no workspace is selected (no implicit default)', () => {
|
||||
expect(() => resolveWorkspaceId({ active: active(undefined) })).toThrow(/no workspace selected/)
|
||||
})
|
||||
})
|
||||
@ -25,16 +25,14 @@ export function resolveWorkspaceId(inputs: WorkspaceResolveInputs): string {
|
||||
throw new BaseError({ code: ErrorCode.UsageInvalidFlag, message: `DIFY_WORKSPACE_ID value ${JSON.stringify(inputs.env)} is not a valid UUID` })
|
||||
return inputs.env
|
||||
}
|
||||
const wsId = inputs.active?.ctx.workspace?.id
|
||||
if (truthy(wsId)) {
|
||||
if (!isValidUuid(wsId)) {
|
||||
throw new BaseError({
|
||||
code: ErrorCode.UsageInvalidFlag,
|
||||
message: `stored workspace ID ${JSON.stringify(wsId)} is not a valid UUID`,
|
||||
hint: 'run \'difyctl use workspace\' to update your active workspace',
|
||||
})
|
||||
const ctx = inputs.active?.ctx
|
||||
if (ctx !== undefined) {
|
||||
if (truthy(ctx.workspace?.id))
|
||||
return ctx.workspace.id
|
||||
if (ctx.available_workspaces !== undefined && ctx.available_workspaces.length > 0
|
||||
&& truthy(ctx.available_workspaces[0]?.id)) {
|
||||
return ctx.available_workspaces[0].id
|
||||
}
|
||||
return wsId
|
||||
}
|
||||
throw new BaseError({
|
||||
code: ErrorCode.UsageMissingArg,
|
||||
|
||||
@ -24,6 +24,7 @@ import {
|
||||
zGetHealthResponse,
|
||||
zGetOauthDeviceLookupQuery,
|
||||
zGetOauthDeviceLookupResponse,
|
||||
zGetPermittedExternalAppsQuery,
|
||||
zGetPermittedExternalAppsResponse,
|
||||
zGetVersionResponse,
|
||||
zGetWorkspacesByWorkspaceIdMembersPath,
|
||||
@ -438,6 +439,7 @@ export const get10 = oc
|
||||
path: '/permitted-external-apps',
|
||||
tags: ['openapi'],
|
||||
})
|
||||
.input(z.object({ query: zGetPermittedExternalAppsQuery.optional() }))
|
||||
.output(zGetPermittedExternalAppsResponse)
|
||||
|
||||
export const permittedExternalApps = {
|
||||
|
||||
@ -656,7 +656,12 @@ export type PostOauthDeviceTokenResponse
|
||||
export type GetPermittedExternalAppsData = {
|
||||
body?: never
|
||||
path?: never
|
||||
query?: never
|
||||
query?: {
|
||||
limit?: number
|
||||
mode?: string
|
||||
name?: string
|
||||
page?: number
|
||||
}
|
||||
url: '/permitted-external-apps'
|
||||
}
|
||||
|
||||
|
||||
@ -638,6 +638,13 @@ export const zPostOauthDeviceTokenBody = zDevicePollRequest
|
||||
*/
|
||||
export const zPostOauthDeviceTokenResponse = z.record(z.string(), z.unknown())
|
||||
|
||||
export const zGetPermittedExternalAppsQuery = z.object({
|
||||
limit: z.int().gte(1).lte(200).optional().default(20),
|
||||
mode: z.string().optional(),
|
||||
name: z.string().max(200).optional(),
|
||||
page: z.int().gte(1).optional().default(1),
|
||||
})
|
||||
|
||||
/**
|
||||
* Permitted external apps list
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user