mirror of
https://github.com/langgenius/dify.git
synced 2026-05-28 12:53:23 +08:00
Compare commits
16 Commits
deploy/saa
...
feat/ui-on
| Author | SHA1 | Date | |
|---|---|---|---|
| 95936a8bac | |||
| ac8a1107ca | |||
| 0c96426d91 | |||
| 67fee14770 | |||
| d94006162d | |||
| 3d53cee8a9 | |||
| 1acd1b568a | |||
| 68f939f3b3 | |||
| 1f4b76ba7e | |||
| 4d974d8f72 | |||
| 1dc12d1661 | |||
| 9cdeffd0b1 | |||
| 09ef785a20 | |||
| 82345977cd | |||
| d2788d7aba | |||
| 83c943bc21 |
4
.github/workflows/deploy-agent-dev.yml
vendored
4
.github/workflows/deploy-agent-dev.yml
vendored
@ -7,7 +7,7 @@ on:
|
||||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "deploy/saas"
|
||||
- "deploy/agent-dev"
|
||||
types:
|
||||
- completed
|
||||
|
||||
@ -16,7 +16,7 @@ jobs:
|
||||
runs-on: depot-ubuntu-24.04
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/saas'
|
||||
github.event.workflow_run.head_branch == 'deploy/agent-dev'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
|
||||
|
||||
@ -38,6 +38,8 @@ from clients.agent_backend.request_builder import (
|
||||
AgentBackendOutputConfig,
|
||||
AgentBackendRunRequestBuilder,
|
||||
AgentBackendWorkflowNodeRunInput,
|
||||
CleanupLayerSpec,
|
||||
extract_cleanup_layer_specs,
|
||||
redact_for_agent_backend_log,
|
||||
)
|
||||
|
||||
@ -68,9 +70,11 @@ __all__ = [
|
||||
"AgentBackendTransportError",
|
||||
"AgentBackendValidationError",
|
||||
"AgentBackendWorkflowNodeRunInput",
|
||||
"CleanupLayerSpec",
|
||||
"DifyAgentBackendRunClient",
|
||||
"FakeAgentBackendRunClient",
|
||||
"FakeAgentBackendScenario",
|
||||
"create_agent_backend_run_client",
|
||||
"extract_cleanup_layer_specs",
|
||||
"redact_for_agent_backend_log",
|
||||
]
|
||||
|
||||
@ -20,6 +20,8 @@ from dify_agent.protocol import (
|
||||
RunEvent,
|
||||
RunFailedEvent,
|
||||
RunFailedEventData,
|
||||
RunPausedEvent,
|
||||
RunPausedEventData,
|
||||
RunStartedEvent,
|
||||
RunStatusResponse,
|
||||
RunSucceededEvent,
|
||||
@ -34,6 +36,7 @@ class FakeAgentBackendScenario(StrEnum):
|
||||
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
PAUSED = "paused"
|
||||
|
||||
|
||||
class FakeAgentBackendRunClient:
|
||||
@ -89,6 +92,13 @@ class FakeAgentBackendRunClient:
|
||||
updated_at=_FIXED_TIME,
|
||||
error="fake failure",
|
||||
)
|
||||
case FakeAgentBackendScenario.PAUSED:
|
||||
return RunStatusResponse(
|
||||
run_id=run_id,
|
||||
status="paused",
|
||||
created_at=_FIXED_TIME,
|
||||
updated_at=_FIXED_TIME,
|
||||
)
|
||||
|
||||
def _events(self, run_id: str) -> tuple[RunEvent, ...]:
|
||||
match self.scenario:
|
||||
@ -115,3 +125,17 @@ class FakeAgentBackendRunClient:
|
||||
data=RunFailedEventData(error="fake failure", reason="unit_test"),
|
||||
),
|
||||
)
|
||||
case FakeAgentBackendScenario.PAUSED:
|
||||
return (
|
||||
RunStartedEvent(id="1-0", run_id=run_id, created_at=_FIXED_TIME),
|
||||
RunPausedEvent(
|
||||
id="2-0",
|
||||
run_id=run_id,
|
||||
created_at=_FIXED_TIME,
|
||||
data=RunPausedEventData(
|
||||
reason="human_input_required",
|
||||
message="Agent requested human input.",
|
||||
session_snapshot=CompositorSessionSnapshot(layers=[]),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@ -11,11 +11,13 @@ composition-driven.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar
|
||||
from typing import ClassVar, cast
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from agenton.compositor.schemas import LayerSessionSnapshot
|
||||
from agenton.layers import ExitIntent
|
||||
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig
|
||||
from agenton_collections.layers.pydantic_ai import PYDANTIC_AI_HISTORY_LAYER_TYPE_ID
|
||||
from dify_agent.layers.dify_plugin import (
|
||||
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
|
||||
@ -29,6 +31,7 @@ from dify_agent.layers.execution_context import (
|
||||
)
|
||||
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
|
||||
from dify_agent.protocol import (
|
||||
DIFY_AGENT_HISTORY_LAYER_ID,
|
||||
DIFY_AGENT_MODEL_LAYER_ID,
|
||||
DIFY_AGENT_OUTPUT_LAYER_ID,
|
||||
CreateRunRequest,
|
||||
@ -45,6 +48,84 @@ WORKFLOW_USER_PROMPT_LAYER_ID = "workflow_user_prompt"
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_ID = "execution_context"
|
||||
DIFY_PLUGIN_TOOLS_LAYER_ID = "tools"
|
||||
|
||||
# Layer types that hold credentials in their per-run config. These are excluded
|
||||
# from the cleanup-replay composition (and from the snapshot that is sent with
|
||||
# the cleanup request) because we deliberately do not persist plaintext
|
||||
# credentials between runs.
|
||||
_CLEANUP_EXCLUDED_LAYER_TYPES: tuple[str, ...] = (
|
||||
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
|
||||
)
|
||||
|
||||
|
||||
class CleanupLayerSpec(BaseModel):
|
||||
"""One layer node replayed by an Agent backend cleanup-only run.
|
||||
|
||||
Cleanup composition cannot include credential-bearing plugin layers, so we
|
||||
persist only the non-plugin layer specs together with the original config.
|
||||
Storing the config (rather than just ``name``/``type``) means cleanup does
|
||||
not depend on the original build-time inputs being re-derivable.
|
||||
"""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
deps: dict[str, str] = Field(default_factory=dict)
|
||||
metadata: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
config: JsonValue = None
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
def extract_cleanup_layer_specs(composition: RunComposition) -> list[CleanupLayerSpec]:
|
||||
"""Project the in-flight composition into the persistable cleanup spec list.
|
||||
|
||||
Plugin layers are intentionally dropped (their configs hold credentials and
|
||||
the lifecycle contract says "do not include an LLM layer" during cleanup).
|
||||
The filtered names must later drive snapshot filtering so the agenton
|
||||
compositor's name-order check still passes for the cleanup run.
|
||||
"""
|
||||
excluded = set(_CLEANUP_EXCLUDED_LAYER_TYPES)
|
||||
specs: list[CleanupLayerSpec] = []
|
||||
for layer in composition.layers:
|
||||
if layer.type in excluded:
|
||||
continue
|
||||
config_value: JsonValue = None
|
||||
if isinstance(layer.config, BaseModel):
|
||||
config_value = layer.config.model_dump(mode="json", warnings=False)
|
||||
else:
|
||||
# ``RunLayerSpec.config`` is typed as ``LayerConfigInput`` which
|
||||
# includes ``Mapping[str, object] | bytes``. In the cleanup-replay
|
||||
# pipeline our builder only emits BaseModel-derived configs or
|
||||
# ``None``, so the wider input alias narrows safely here.
|
||||
config_value = cast(JsonValue, layer.config)
|
||||
specs.append(
|
||||
CleanupLayerSpec(
|
||||
name=layer.name,
|
||||
type=layer.type,
|
||||
deps=dict(layer.deps),
|
||||
metadata=dict(layer.metadata),
|
||||
config=config_value,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
|
||||
|
||||
def _filter_snapshot_to_specs(
|
||||
snapshot: CompositorSessionSnapshot,
|
||||
specs: list[CleanupLayerSpec],
|
||||
) -> CompositorSessionSnapshot:
|
||||
"""Keep only snapshot layers whose names appear in the cleanup spec list.
|
||||
|
||||
The agenton compositor rejects a snapshot whose layer-name sequence does
|
||||
not match the active composition exactly. Cleanup-replay drops plugin
|
||||
layers, so we must drop the matching snapshot entries here.
|
||||
"""
|
||||
kept_names = {spec.name for spec in specs}
|
||||
filtered_layers: list[LayerSessionSnapshot] = [layer for layer in snapshot.layers if layer.name in kept_names]
|
||||
if len(filtered_layers) == len(snapshot.layers):
|
||||
return snapshot
|
||||
return CompositorSessionSnapshot(schema_version=snapshot.schema_version, layers=filtered_layers)
|
||||
|
||||
|
||||
class AgentBackendModelConfig(BaseModel):
|
||||
"""API-side model/plugin selection before it is converted to Dify Agent layers."""
|
||||
@ -86,7 +167,8 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
|
||||
output: AgentBackendOutputConfig | None = None
|
||||
tools: DifyPluginToolsLayerConfig | None = None
|
||||
session_snapshot: CompositorSessionSnapshot | None = None
|
||||
suspend_on_exit: bool = False
|
||||
include_history: bool = True
|
||||
suspend_on_exit: bool = True
|
||||
metadata: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
@ -102,6 +184,50 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
|
||||
class AgentBackendRunRequestBuilder:
|
||||
"""Converts API product state into the public ``dify-agent`` run protocol."""
|
||||
|
||||
def build_cleanup_request(
|
||||
self,
|
||||
*,
|
||||
session_snapshot: CompositorSessionSnapshot,
|
||||
composition_layer_specs: list[CleanupLayerSpec],
|
||||
idempotency_key: str | None = None,
|
||||
metadata: dict[str, JsonValue] | None = None,
|
||||
) -> CreateRunRequest:
|
||||
"""Build a lifecycle-only cleanup request that replays the prior layers.
|
||||
|
||||
The agenton compositor enforces that the session snapshot's layer names
|
||||
match the active composition in order, so cleanup must replay the same
|
||||
non-plugin layer graph that produced the snapshot. Plugin layers
|
||||
(``dify.plugin.llm``, ``dify.plugin.tools``) are excluded from both the
|
||||
composition and the snapshot before submission because their configs
|
||||
require credentials that are not persisted between runs.
|
||||
"""
|
||||
if not composition_layer_specs:
|
||||
raise ValueError(
|
||||
"build_cleanup_request requires composition_layer_specs; an empty "
|
||||
"composition would fail the agent backend's snapshot validation."
|
||||
)
|
||||
request_metadata = dict(metadata or {})
|
||||
request_metadata["agent_backend_lifecycle"] = "session_cleanup"
|
||||
layers = [
|
||||
RunLayerSpec(
|
||||
name=spec.name,
|
||||
type=spec.type,
|
||||
deps=dict(spec.deps),
|
||||
metadata=dict(spec.metadata),
|
||||
config=spec.config,
|
||||
)
|
||||
for spec in composition_layer_specs
|
||||
]
|
||||
filtered_snapshot = _filter_snapshot_to_specs(session_snapshot, composition_layer_specs)
|
||||
return CreateRunRequest(
|
||||
composition=RunComposition(layers=layers),
|
||||
purpose="workflow_node",
|
||||
idempotency_key=idempotency_key,
|
||||
metadata=request_metadata,
|
||||
session_snapshot=filtered_snapshot,
|
||||
on_exit=LayerExitSignals(default=ExitIntent.DELETE),
|
||||
)
|
||||
|
||||
def build_for_workflow_node(self, run_input: AgentBackendWorkflowNodeRunInput) -> CreateRunRequest:
|
||||
"""Build a workflow Agent Node run request without defining another wire schema."""
|
||||
layers: list[RunLayerSpec] = []
|
||||
@ -135,6 +261,20 @@ class AgentBackendRunRequestBuilder:
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.execution_context,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if run_input.include_history:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_HISTORY_LAYER_ID,
|
||||
type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID,
|
||||
metadata={**run_input.metadata, "origin": "agent_session_history"},
|
||||
)
|
||||
)
|
||||
|
||||
layers.extend(
|
||||
[
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_MODEL_LAYER_ID,
|
||||
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
|
||||
@ -4,7 +4,7 @@ from datetime import UTC, datetime
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
@ -17,18 +17,17 @@ from controllers.openapi._models import (
|
||||
SessionRow,
|
||||
WorkspacePayload,
|
||||
)
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
AuthContext,
|
||||
SubjectType,
|
||||
Scope,
|
||||
TokenType,
|
||||
get_auth_ctx,
|
||||
validate_bearer,
|
||||
)
|
||||
from libs.rate_limit import (
|
||||
LIMIT_ME_PER_ACCOUNT,
|
||||
LIMIT_ME_PER_EMAIL,
|
||||
enforce,
|
||||
)
|
||||
from services.account_service import AccountService, TenantService
|
||||
@ -42,32 +41,18 @@ from services.oauth_device_flow import (
|
||||
@openapi_ns.route("/account")
|
||||
class AccountApi(Resource):
|
||||
@openapi_ns.response(200, "Account info", openapi_ns.models[AccountResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self):
|
||||
ctx = get_auth_ctx()
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{auth_data.account_id}")
|
||||
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}")
|
||||
else:
|
||||
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}")
|
||||
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
return AccountResponse(
|
||||
subject_type=ctx.subject_type,
|
||||
subject_email=ctx.subject_email,
|
||||
subject_issuer=ctx.subject_issuer,
|
||||
account=None,
|
||||
workspaces=[],
|
||||
default_workspace_id=None,
|
||||
).model_dump(mode="json")
|
||||
|
||||
account = AccountService.get_account_by_id(db.session, str(ctx.account_id)) if ctx.account_id else None
|
||||
memberships = TenantService.get_account_memberships(db.session, str(ctx.account_id)) if ctx.account_id else []
|
||||
account_id_str = str(auth_data.account_id) if auth_data.account_id else None
|
||||
account = AccountService.get_account_by_id(db.session, account_id_str) if account_id_str else None
|
||||
memberships = TenantService.get_account_memberships(db.session, account_id_str) if account_id_str else []
|
||||
default_ws_id = _pick_default_workspace(memberships)
|
||||
|
||||
return AccountResponse(
|
||||
subject_type=ctx.subject_type,
|
||||
subject_email=ctx.subject_email or (account.email if account else None),
|
||||
subject_type="account",
|
||||
subject_email=account.email if account else None,
|
||||
account=_account_payload(account) if account else None,
|
||||
workspaces=[_workspace_payload(m) for m in memberships],
|
||||
default_workspace_id=default_ws_id,
|
||||
@ -77,19 +62,17 @@ class AccountApi(Resource):
|
||||
@openapi_ns.route("/account/sessions/self")
|
||||
class AccountSessionsSelfApi(Resource):
|
||||
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def delete(self):
|
||||
ctx = get_auth_ctx()
|
||||
_require_oauth_subject(ctx)
|
||||
revoke_oauth_token(db.session, redis_client, str(ctx.token_id))
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def delete(self, *, auth_data: AuthData):
|
||||
revoke_oauth_token(db.session, redis_client, str(auth_data.token_id))
|
||||
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions")
|
||||
class AccountSessionsApi(Resource):
|
||||
@openapi_ns.response(200, "Session list", openapi_ns.models[SessionListResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self):
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
ctx = get_auth_ctx()
|
||||
now = datetime.now(UTC)
|
||||
page = int(request.args.get("page", "1"))
|
||||
@ -122,10 +105,9 @@ class AccountSessionsApi(Resource):
|
||||
@openapi_ns.route("/account/sessions/<string:session_id>")
|
||||
class AccountSessionByIdApi(Resource):
|
||||
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def delete(self, session_id: str):
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def delete(self, session_id: str, *, auth_data: AuthData):
|
||||
ctx = get_auth_ctx()
|
||||
_require_oauth_subject(ctx)
|
||||
|
||||
# 404 (not 403) on cross-subject so the endpoint doesn't leak
|
||||
# token IDs that belong to other subjects.
|
||||
@ -136,13 +118,6 @@ class AccountSessionByIdApi(Resource):
|
||||
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||
|
||||
|
||||
def _require_oauth_subject(ctx: AuthContext) -> None:
|
||||
if not ctx.source.startswith("oauth"):
|
||||
raise BadRequest(
|
||||
"this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs"
|
||||
)
|
||||
|
||||
|
||||
def _iso(dt: datetime | None) -> str | None:
|
||||
if dt is None:
|
||||
return None
|
||||
|
||||
@ -16,7 +16,8 @@ import services
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._audit import emit_app_run
|
||||
from controllers.openapi._models import AppRunRequest
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
@ -124,8 +125,9 @@ _DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = {
|
||||
class AppRunApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[AppRunRequest.__name__])
|
||||
@openapi_ns.response(200, "Run result (SSE stream)")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
payload = AppRunRequest.model_validate(body)
|
||||
@ -158,8 +160,9 @@ class AppRunApi(Resource):
|
||||
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/stop")
|
||||
class AppRunTaskStopApi(Resource):
|
||||
@openapi_ns.response(200, "Task stopped")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, task_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
return {"result": "success"}
|
||||
|
||||
@ -1,9 +1,4 @@
|
||||
"""GET /openapi/v1/apps and per-app reads.
|
||||
|
||||
Decorator order: `method_decorators` is innermost-first. `validate_bearer`
|
||||
is last → outermost → publishes the auth ContextVar before `require_scope`
|
||||
reads it.
|
||||
"""
|
||||
"""GET /openapi/v1/apps and per-app reads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -28,31 +23,17 @@ from controllers.openapi._models import (
|
||||
AppListRow,
|
||||
TagItem,
|
||||
)
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
AuthContext,
|
||||
Scope,
|
||||
SubjectType,
|
||||
get_auth_ctx,
|
||||
require_scope,
|
||||
require_workspace_member,
|
||||
validate_bearer,
|
||||
)
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppListParams, AppService
|
||||
from services.tag_service import TagService
|
||||
|
||||
_APPS_READ_DECORATORS = [
|
||||
require_scope(Scope.APPS_READ),
|
||||
accept_subjects(SubjectType.ACCOUNT),
|
||||
validate_bearer(accept=ACCEPT_USER_ANY),
|
||||
]
|
||||
|
||||
_ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"})
|
||||
|
||||
|
||||
@ -66,13 +47,9 @@ _EMPTY_PARAMETERS: dict[str, Any] = {
|
||||
|
||||
|
||||
class AppReadResource(Resource):
|
||||
"""Base for per-app read endpoints; subclasses call `_load()` for SSO/membership/exists checks."""
|
||||
|
||||
method_decorators = _APPS_READ_DECORATORS
|
||||
|
||||
def _load(self, app_id: str, workspace_id: str | None = None) -> tuple[App, AuthContext]:
|
||||
ctx: AuthContext = get_auth_ctx()
|
||||
"""Base for per-app read endpoints; subclasses call `_load()` for membership/exists checks."""
|
||||
|
||||
def _load(self, app_id: str, workspace_id: str | None = None) -> App:
|
||||
try:
|
||||
parsed_uuid = _uuid.UUID(app_id)
|
||||
is_uuid = True
|
||||
@ -99,8 +76,7 @@ class AppReadResource(Resource):
|
||||
raise Conflict("".join(lines))
|
||||
app = matches[0]
|
||||
|
||||
require_workspace_member(ctx, str(app.tenant_id))
|
||||
return app, ctx
|
||||
return app
|
||||
|
||||
|
||||
def parameters_payload(app: App) -> dict:
|
||||
@ -114,13 +90,14 @@ def parameters_payload(app: App) -> dict:
|
||||
class AppDescribeApi(AppReadResource):
|
||||
@openapi_ns.doc(params=query_params_from_model(AppDescribeQuery))
|
||||
@openapi_ns.response(200, "App description", openapi_ns.models[AppDescribeResponse.__name__])
|
||||
def get(self, app_id: str):
|
||||
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, app_id: str, *, auth_data: AuthData):
|
||||
try:
|
||||
query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
app, _ = self._load(app_id, workspace_id=query.workspace_id)
|
||||
app = self._load(app_id, workspace_id=query.workspace_id)
|
||||
|
||||
requested = query.fields
|
||||
want_info = requested is None or "info" in requested
|
||||
@ -168,20 +145,16 @@ class AppDescribeApi(AppReadResource):
|
||||
|
||||
@openapi_ns.route("/apps")
|
||||
class AppListApi(Resource):
|
||||
method_decorators = _APPS_READ_DECORATORS
|
||||
|
||||
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
|
||||
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
|
||||
def get(self):
|
||||
ctx: AuthContext = get_auth_ctx()
|
||||
|
||||
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
try:
|
||||
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
workspace_id = query.workspace_id
|
||||
require_workspace_member(ctx, workspace_id)
|
||||
|
||||
empty = (
|
||||
AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[]).model_dump(
|
||||
@ -237,7 +210,7 @@ class AppListApi(Resource):
|
||||
openapi_visible=True,
|
||||
)
|
||||
|
||||
pagination = AppService().get_paginate_apps(str(ctx.account_id), workspace_id, params)
|
||||
pagination = AppService().get_paginate_apps(str(auth_data.account_id), workspace_id, params)
|
||||
if pagination is None:
|
||||
return empty
|
||||
|
||||
|
||||
@ -18,37 +18,27 @@ from controllers.openapi._models import (
|
||||
PermittedExternalAppsListQuery,
|
||||
PermittedExternalAppsListResponse,
|
||||
)
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData, Edition
|
||||
from extensions.ext_database import db
|
||||
from libs.device_flow_security import enterprise_only
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
Scope,
|
||||
SubjectType,
|
||||
require_scope,
|
||||
validate_bearer,
|
||||
)
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.app_permitted_service import list_permitted_apps
|
||||
from services.openapi.license_gate import license_required
|
||||
|
||||
|
||||
@openapi_ns.route("/permitted-external-apps")
|
||||
class PermittedExternalAppsListApi(Resource):
|
||||
method_decorators = [
|
||||
require_scope(Scope.APPS_READ_PERMITTED_EXTERNAL),
|
||||
license_required,
|
||||
accept_subjects(SubjectType.EXTERNAL_SSO),
|
||||
validate_bearer(accept=ACCEPT_USER_ANY),
|
||||
enterprise_only,
|
||||
]
|
||||
|
||||
@openapi_ns.response(
|
||||
200, "Permitted external apps list", openapi_ns.models[PermittedExternalAppsListResponse.__name__]
|
||||
)
|
||||
def get(self):
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_READ_PERMITTED_EXTERNAL,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_EXTERNAL_SSO}),
|
||||
edition=frozenset({Edition.EE}),
|
||||
)
|
||||
def get(self, *, auth_data: AuthData):
|
||||
try:
|
||||
query = PermittedExternalAppsListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
|
||||
__all__ = ["OAUTH_BEARER_PIPELINE"]
|
||||
__all__ = ["auth_router"]
|
||||
|
||||
@ -1,46 +1,64 @@
|
||||
"""`OAUTH_BEARER_PIPELINE` — the auth scheme for openapi `/run` endpoints.
|
||||
|
||||
Endpoints attach via `@OAUTH_BEARER_PIPELINE.guard(scope=…)`. No alternative
|
||||
paths. Read endpoints (`/apps`, `/info`, `/parameters`, `/describe`) skip
|
||||
the pipeline and use `validate_bearer + require_scope + require_workspace_member`
|
||||
inline — they don't need `AppAuthzCheck`/`CallerMount`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
from controllers.openapi.auth.steps import (
|
||||
AppAuthzCheck,
|
||||
AppResolver,
|
||||
BearerCheck,
|
||||
CallerMount,
|
||||
ScopeCheck,
|
||||
SurfaceCheck,
|
||||
WorkspaceMembershipCheck,
|
||||
from controllers.openapi.auth.conditions import (
|
||||
EDITION_CE,
|
||||
EDITION_EE,
|
||||
LOADED_APP_IS_PRIVATE,
|
||||
PATH_HAS_APP_ID,
|
||||
WEBAPP_AUTH_ENABLED,
|
||||
)
|
||||
from controllers.openapi.auth.strategies import (
|
||||
AccountMounter,
|
||||
AclStrategy,
|
||||
AppAuthzStrategy,
|
||||
EndUserMounter,
|
||||
MembershipStrategy,
|
||||
from controllers.openapi.auth.data import Edition
|
||||
from controllers.openapi.auth.flow import When
|
||||
from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter
|
||||
from controllers.openapi.auth.prepare import (
|
||||
load_account,
|
||||
load_app,
|
||||
load_app_access_mode,
|
||||
load_tenant,
|
||||
resolve_external_user,
|
||||
)
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
def _resolve_app_authz_strategy() -> AppAuthzStrategy:
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
return AclStrategy()
|
||||
return MembershipStrategy()
|
||||
|
||||
|
||||
OAUTH_BEARER_PIPELINE = Pipeline(
|
||||
BearerCheck(),
|
||||
SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT})),
|
||||
ScopeCheck(),
|
||||
AppResolver(),
|
||||
WorkspaceMembershipCheck(),
|
||||
AppAuthzCheck(_resolve_app_authz_strategy),
|
||||
CallerMount(AccountMounter(), EndUserMounter()),
|
||||
from controllers.openapi.auth.verify import (
|
||||
check_acl,
|
||||
check_app_access,
|
||||
check_membership,
|
||||
check_private_app_permission,
|
||||
check_scope,
|
||||
)
|
||||
from libs.oauth_bearer import TokenType
|
||||
|
||||
account_pipeline = AuthPipeline(
|
||||
prepare=[
|
||||
When(PATH_HAS_APP_ID, then=load_app),
|
||||
When(PATH_HAS_APP_ID, then=load_tenant),
|
||||
load_account, # all tokens here are account tokens
|
||||
When(PATH_HAS_APP_ID & EDITION_EE, then=load_app_access_mode),
|
||||
],
|
||||
auth=[
|
||||
check_scope,
|
||||
When(EDITION_CE & PATH_HAS_APP_ID, then=check_membership),
|
||||
When(EDITION_EE & PATH_HAS_APP_ID & ~WEBAPP_AUTH_ENABLED, then=check_app_access),
|
||||
When(PATH_HAS_APP_ID & EDITION_EE & WEBAPP_AUTH_ENABLED, then=check_acl),
|
||||
When(EDITION_EE & LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
|
||||
],
|
||||
)
|
||||
|
||||
external_sso_pipeline = AuthPipeline(
|
||||
prepare=[
|
||||
When(PATH_HAS_APP_ID, then=load_app),
|
||||
When(PATH_HAS_APP_ID, then=load_tenant),
|
||||
When(PATH_HAS_APP_ID, then=resolve_external_user),
|
||||
When(PATH_HAS_APP_ID, then=load_app_access_mode),
|
||||
],
|
||||
auth=[
|
||||
check_scope,
|
||||
When(PATH_HAS_APP_ID & WEBAPP_AUTH_ENABLED, then=check_acl),
|
||||
When(LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
|
||||
],
|
||||
)
|
||||
|
||||
auth_router = PipelineRouter(
|
||||
{
|
||||
TokenType.OAUTH_ACCOUNT: PipelineRoute(account_pipeline),
|
||||
TokenType.OAUTH_EXTERNAL_SSO: PipelineRoute(external_sso_pipeline, required_edition=frozenset({Edition.EE})),
|
||||
}
|
||||
)
|
||||
|
||||
53
api/controllers/openapi/auth/conditions.py
Normal file
53
api/controllers/openapi/auth/conditions.py
Normal file
@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from controllers.openapi.auth.data import AuthData, Edition, RequestContext, current_edition
|
||||
from libs.oauth_bearer import TokenType
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
CondFn = Callable[[RequestContext, AuthData | None], bool]
|
||||
|
||||
|
||||
class Cond:
|
||||
def __init__(self, fn: CondFn) -> None:
|
||||
self._fn = fn
|
||||
|
||||
def __call__(self, ctx: RequestContext, data: AuthData | None = None) -> bool:
|
||||
return self._fn(ctx, data)
|
||||
|
||||
def __and__(self, other: Cond) -> Cond:
|
||||
return Cond(lambda ctx, data: self(ctx, data) and other(ctx, data))
|
||||
|
||||
def __or__(self, other: Cond) -> Cond:
|
||||
return Cond(lambda ctx, data: self(ctx, data) or other(ctx, data))
|
||||
|
||||
def __invert__(self) -> Cond:
|
||||
return Cond(lambda ctx, data: not self(ctx, data))
|
||||
|
||||
|
||||
def request_cond(fn: Callable[[RequestContext], bool]) -> Cond:
|
||||
return Cond(lambda ctx, _: fn(ctx))
|
||||
|
||||
|
||||
def data_cond(fn: Callable[[AuthData], bool]) -> Cond:
|
||||
return Cond(lambda _, data: data is not None and fn(data))
|
||||
|
||||
|
||||
def config_cond(fn: Callable[[], bool]) -> Cond:
|
||||
return Cond(lambda _, __: fn())
|
||||
|
||||
|
||||
TOKEN_IS_OAUTH_ACCOUNT = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_ACCOUNT)
|
||||
TOKEN_IS_OAUTH_EXTERNAL_SSO = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_EXTERNAL_SSO)
|
||||
|
||||
PATH_HAS_APP_ID = request_cond(lambda ctx: "app_id" in ctx.path_params)
|
||||
|
||||
EDITION_CE = config_cond(lambda: current_edition() == Edition.CE)
|
||||
EDITION_EE = config_cond(lambda: current_edition() == Edition.EE)
|
||||
EDITION_SAAS = config_cond(lambda: current_edition() == Edition.SAAS)
|
||||
|
||||
WEBAPP_AUTH_ENABLED = config_cond(lambda: FeatureService.get_system_features().webapp_auth.enabled)
|
||||
|
||||
LOADED_APP_IS_PRIVATE = data_cond(lambda data: data.app_access_mode == WebAppAccessMode.PRIVATE)
|
||||
@ -1,68 +0,0 @@
|
||||
"""Mutable per-request context for the openapi auth pipeline.
|
||||
|
||||
Every field starts None / empty and is filled in by a step. The pipeline
|
||||
is the only thing that should construct or mutate Context — handlers
|
||||
read populated values via the decorator's kwargs unpacking.
|
||||
|
||||
Context is intentionally decoupled from Flask's ``Request``: the pipeline
|
||||
guard extracts whatever transport-level inputs the steps need (bearer
|
||||
token, path params) at the boundary and writes them into Context fields,
|
||||
so steps stay testable without a request object and won't leak coupling
|
||||
to a specific framework.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from contextvars import Token
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Literal, Protocol
|
||||
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models import App, Tenant
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
required_scope: Scope
|
||||
bearer_token: str | None = None
|
||||
path_params: Mapping[str, str] = field(default_factory=dict)
|
||||
subject_type: SubjectType | None = None
|
||||
subject_email: str | None = None
|
||||
subject_issuer: str | None = None
|
||||
account_id: uuid.UUID | None = None
|
||||
scopes: frozenset[Scope] = field(default_factory=frozenset)
|
||||
token_id: uuid.UUID | None = None
|
||||
token_hash: str | None = None
|
||||
cached_verified_tenants: dict[str, bool] | None = None
|
||||
source: str | None = None
|
||||
expires_at: datetime | None = None
|
||||
app: App | None = None
|
||||
tenant: Tenant | None = None
|
||||
caller: object | None = None
|
||||
caller_kind: Literal["account", "end_user"] | None = None
|
||||
auth_ctx_reset_token: Token[AuthContext] | None = None
|
||||
|
||||
@property
|
||||
def must_tenant(self) -> Tenant:
|
||||
if not self.tenant:
|
||||
raise Unauthorized("tenant is not associated")
|
||||
return self.tenant
|
||||
|
||||
@property
|
||||
def must_subject_type(self) -> SubjectType:
|
||||
if not self.subject_type:
|
||||
raise Unauthorized("subject_type unset — BearerCheck did not run")
|
||||
return self.subject_type
|
||||
|
||||
|
||||
class Step(Protocol):
|
||||
"""One responsibility. Mutate ctx or raise to short-circuit."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None: ...
|
||||
69
api/controllers/openapi/auth/data.py
Normal file
69
api/controllers/openapi/auth/data.py
Normal file
@ -0,0 +1,69 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from configs import dify_config
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models.account import Account, Tenant
|
||||
from models.model import App, EndUser
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
|
||||
|
||||
class Edition(StrEnum):
|
||||
CE = "ce"
|
||||
EE = "ee"
|
||||
SAAS = "saas"
|
||||
|
||||
|
||||
def current_edition() -> Edition:
|
||||
if dify_config.EDITION == "CLOUD":
|
||||
return Edition.SAAS
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
return Edition.EE
|
||||
return Edition.CE
|
||||
|
||||
|
||||
class ExternalIdentity(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
email: str
|
||||
issuer: str | None = None
|
||||
|
||||
|
||||
class RequestContext(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
token_type: TokenType
|
||||
scope: Scope | None = None
|
||||
path_params: dict[str, str]
|
||||
|
||||
|
||||
class AuthData(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
required_scope: Scope | None = None
|
||||
token_type: TokenType
|
||||
account_id: uuid.UUID | None = None
|
||||
token_hash: str
|
||||
token_id: uuid.UUID | None = None
|
||||
scopes: frozenset[Scope]
|
||||
tenants: dict[str, bool] = Field(default_factory=dict)
|
||||
external_identity: ExternalIdentity | None = None
|
||||
path_params: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
app: App | None = None
|
||||
tenant: Tenant | None = None
|
||||
app_access_mode: WebAppAccessMode | None = None
|
||||
|
||||
caller: Account | EndUser | None = None
|
||||
caller_kind: Literal["account", "end_user"] | None = None
|
||||
|
||||
def require_app_context(self) -> tuple[App, Account | EndUser, Literal["account", "end_user"]]:
|
||||
if self.app is None or self.caller is None or self.caller_kind is None:
|
||||
raise InternalServerError("pipeline_invariant_violated: app context missing")
|
||||
return self.app, self.caller, self.caller_kind
|
||||
19
api/controllers/openapi/auth/flow.py
Normal file
19
api/controllers/openapi/auth/flow.py
Normal file
@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from controllers.openapi.auth.conditions import Cond
|
||||
from controllers.openapi.auth.data import AuthData, RequestContext
|
||||
|
||||
|
||||
class When:
|
||||
def __init__(self, condition: Cond, *, then: Callable[[Any], None]) -> None:
|
||||
self.condition = condition
|
||||
self._step = then
|
||||
|
||||
def applies(self, ctx: RequestContext, data: AuthData | None = None) -> bool:
|
||||
return self.condition(ctx, data)
|
||||
|
||||
def __call__(self, arg: Any) -> None:
|
||||
self._step(arg)
|
||||
@ -1,51 +1,209 @@
|
||||
"""Pipeline IS the auth scheme.
|
||||
"""Auth pipeline — entry point for all openapi auth.
|
||||
|
||||
`Pipeline.guard(scope=…)` is the only attachment point for endpoints —
|
||||
that is the design lock-in: forgetting an auth layer is structurally
|
||||
impossible because there is no "sometimes wrap, sometimes don't" choice.
|
||||
`PipelineRouter.guard()` is the only attachment point for endpoints.
|
||||
`AuthPipeline` is a pure step-runner with no routing concerns.
|
||||
`PipelineRoute` binds a pipeline to optional edition requirements.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.context import Context, Step
|
||||
from libs.oauth_bearer import Scope, extract_bearer, reset_auth_ctx
|
||||
from controllers.openapi._audit import emit_wrong_surface
|
||||
from controllers.openapi.auth.data import (
|
||||
AuthData,
|
||||
Edition,
|
||||
ExternalIdentity,
|
||||
RequestContext,
|
||||
current_edition,
|
||||
)
|
||||
from controllers.openapi.auth.flow import When
|
||||
from libs.oauth_bearer import (
|
||||
AuthContext,
|
||||
Scope,
|
||||
TokenType,
|
||||
extract_bearer,
|
||||
get_authenticator,
|
||||
reset_auth_ctx,
|
||||
set_auth_ctx,
|
||||
)
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
|
||||
|
||||
class Pipeline:
|
||||
def __init__(self, *steps: Step) -> None:
|
||||
self._steps = steps
|
||||
class AuthPipeline:
|
||||
"""Pure step-runner — no routing, no guard.
|
||||
|
||||
def run(self, ctx: Context) -> None:
|
||||
for step in self._steps:
|
||||
step(ctx)
|
||||
Both `prepare` and `auth` steps receive the same `AuthData` instance.
|
||||
`prepare` steps populate it; `auth` steps validate it.
|
||||
"""
|
||||
|
||||
def guard(self, *, scope: Scope):
|
||||
def decorator(view):
|
||||
def __init__(self, prepare: list, auth: list) -> None:
|
||||
self._prepare = prepare
|
||||
self._auth = auth
|
||||
|
||||
def _run(
|
||||
self,
|
||||
identity: AuthContext,
|
||||
args: tuple,
|
||||
kwargs: dict,
|
||||
view: Callable,
|
||||
*,
|
||||
scope: Scope | None,
|
||||
) -> Any:
|
||||
req_ctx = RequestContext(
|
||||
token_type=identity.token_type,
|
||||
scope=scope,
|
||||
path_params=dict(request.view_args or {}),
|
||||
)
|
||||
|
||||
data = AuthData(
|
||||
token_type=identity.token_type,
|
||||
account_id=identity.account_id,
|
||||
token_hash=identity.token_hash,
|
||||
token_id=identity.token_id,
|
||||
scopes=frozenset(identity.scopes),
|
||||
tenants=dict(identity.verified_tenants),
|
||||
required_scope=scope,
|
||||
path_params=dict(req_ctx.path_params),
|
||||
external_identity=(
|
||||
ExternalIdentity(email=identity.subject_email, issuer=identity.subject_issuer)
|
||||
if identity.subject_email
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
for step in self._prepare:
|
||||
if _should_run(step, req_ctx, data=None):
|
||||
step(data)
|
||||
|
||||
for step in self._auth:
|
||||
if _should_run(step, req_ctx, data=data):
|
||||
step(data)
|
||||
|
||||
reset_token = set_auth_ctx(identity)
|
||||
if data.caller:
|
||||
_mount_flask_login(data.caller)
|
||||
|
||||
try:
|
||||
kwargs["auth_data"] = data
|
||||
return view(*args, **kwargs)
|
||||
finally:
|
||||
reset_auth_ctx(reset_token)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineRoute:
|
||||
pipeline: AuthPipeline
|
||||
required_edition: frozenset[Edition] | None = None
|
||||
|
||||
|
||||
class PipelineRouter:
|
||||
"""Entry point for openapi auth.
|
||||
|
||||
`guard()` is the decorator that endpoints attach to. It applies
|
||||
global gates (edition, token type) then dispatches to the matching
|
||||
`PipelineRoute` for the token type.
|
||||
"""
|
||||
|
||||
def __init__(self, routes: dict[TokenType, PipelineRoute]) -> None:
|
||||
self._routes = routes
|
||||
|
||||
def guard(
|
||||
self,
|
||||
*,
|
||||
scope: Scope | None = None,
|
||||
allowed_token_types: frozenset[TokenType] | None = None,
|
||||
edition: frozenset[Edition] | None = None,
|
||||
) -> Callable:
|
||||
def decorator(view: Callable) -> Callable:
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
# Extract transport-level inputs at the boundary so steps
|
||||
# stay decoupled from Flask's request object.
|
||||
ctx = Context(
|
||||
required_scope=scope,
|
||||
bearer_token=extract_bearer(request),
|
||||
path_params=dict(request.view_args or {}),
|
||||
def decorated(*args: Any, **kwargs: Any) -> Any:
|
||||
return self._execute(
|
||||
args,
|
||||
kwargs,
|
||||
view,
|
||||
scope=scope,
|
||||
allowed_token_types=allowed_token_types,
|
||||
edition=edition,
|
||||
)
|
||||
try:
|
||||
self.run(ctx)
|
||||
kwargs.update(
|
||||
app_model=ctx.app,
|
||||
caller=ctx.caller,
|
||||
caller_kind=ctx.caller_kind,
|
||||
)
|
||||
return view(*args, **kwargs)
|
||||
finally:
|
||||
if ctx.auth_ctx_reset_token is not None:
|
||||
reset_auth_ctx(ctx.auth_ctx_reset_token)
|
||||
|
||||
return decorated
|
||||
|
||||
return decorator
|
||||
|
||||
def _execute(
|
||||
self,
|
||||
args: tuple,
|
||||
kwargs: dict,
|
||||
view: Callable,
|
||||
*,
|
||||
scope: Scope | None,
|
||||
allowed_token_types: frozenset[TokenType] | None,
|
||||
edition: frozenset[Edition] | None,
|
||||
) -> Any:
|
||||
# 404 not 403 — this edition doesn't expose the feature at all
|
||||
if edition is not None and current_edition() not in edition:
|
||||
raise NotFound()
|
||||
|
||||
license_checked = False
|
||||
if edition is not None and Edition.EE in edition:
|
||||
_check_license()
|
||||
license_checked = True
|
||||
|
||||
token = extract_bearer(request)
|
||||
if not token:
|
||||
raise Unauthorized("bearer required")
|
||||
|
||||
identity = get_authenticator().authenticate(token)
|
||||
|
||||
if allowed_token_types is not None and identity.token_type not in allowed_token_types:
|
||||
emit_wrong_surface(
|
||||
subject_type=_subject_type_str(identity),
|
||||
attempted_path=request.path,
|
||||
client_id=getattr(identity, "client_id", None),
|
||||
token_id=str(identity.token_id) if identity.token_id else None,
|
||||
)
|
||||
raise Forbidden("unsupported_token_type")
|
||||
|
||||
route = self._routes.get(identity.token_type)
|
||||
if route is None:
|
||||
raise Forbidden("unsupported_token_type")
|
||||
|
||||
if route.required_edition is not None:
|
||||
if current_edition() not in route.required_edition:
|
||||
raise Forbidden("external_sso_requires_ee")
|
||||
if not license_checked and Edition.EE in route.required_edition:
|
||||
_check_license()
|
||||
|
||||
return route.pipeline._run(identity, args, kwargs, view, scope=scope)
|
||||
|
||||
|
||||
def _should_run(step: Any, req_ctx: RequestContext, data: AuthData | None) -> bool:
|
||||
if isinstance(step, When):
|
||||
return step.applies(req_ctx, data)
|
||||
return True
|
||||
|
||||
|
||||
def _subject_type_str(identity: Any) -> str | None:
|
||||
subject = getattr(identity, "subject_type", None)
|
||||
if subject is None:
|
||||
return None
|
||||
return subject.value if hasattr(subject, "value") else str(subject)
|
||||
|
||||
|
||||
def _check_license() -> None:
|
||||
settings = FeatureService.get_system_features()
|
||||
if settings.license.status in {LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST}:
|
||||
raise Forbidden("license_invalid")
|
||||
|
||||
|
||||
def _mount_flask_login(user: Any) -> None:
|
||||
current_app.login_manager._update_request_context_with_user(user) # type: ignore[attr-defined]
|
||||
user_logged_in.send(current_app._get_current_object(), user=user) # type: ignore[attr-defined]
|
||||
|
||||
67
api/controllers/openapi/auth/prepare.py
Normal file
67
api/controllers/openapi/auth/prepare.py
Normal file
@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantStatus
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_service import AppService
|
||||
from services.end_user_service import EndUserService
|
||||
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode
|
||||
|
||||
|
||||
def load_app(data: AuthData) -> None:
|
||||
app_id = data.path_params["app_id"]
|
||||
app = AppService.get_app_by_id(db.session, app_id)
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
if not app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
data.app = app
|
||||
|
||||
|
||||
def load_tenant(data: AuthData) -> None:
|
||||
if data.app is None:
|
||||
raise InternalServerError("pipeline_invariant_violated: app not loaded before load_tenant")
|
||||
tenant = TenantService.get_tenant_by_id(db.session, str(data.app.tenant_id))
|
||||
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("workspace unavailable")
|
||||
data.tenant = tenant
|
||||
|
||||
|
||||
def load_account(data: AuthData) -> None:
|
||||
account = AccountService.get_account_by_id(db.session, str(data.account_id))
|
||||
if account is None:
|
||||
raise Unauthorized("account not found")
|
||||
if data.tenant:
|
||||
account.current_tenant = data.tenant
|
||||
data.caller = account
|
||||
data.caller_kind = "account"
|
||||
|
||||
|
||||
def resolve_external_user(data: AuthData) -> None:
|
||||
if data.tenant is None or data.app is None or data.external_identity is None:
|
||||
raise Unauthorized("missing context for external user resolution")
|
||||
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||
InvokeFrom.OPENAPI,
|
||||
tenant_id=str(data.tenant.id),
|
||||
app_id=str(data.app.id),
|
||||
user_id=data.external_identity.email,
|
||||
)
|
||||
data.caller = end_user
|
||||
data.caller_kind = "end_user"
|
||||
|
||||
|
||||
def load_app_access_mode(data: AuthData) -> None:
|
||||
if data.app is None:
|
||||
return
|
||||
try:
|
||||
settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(data.app.id))
|
||||
if settings is None:
|
||||
data.app_access_mode = None
|
||||
return
|
||||
data.app_access_mode = WebAppAccessMode(settings.access_mode)
|
||||
except ValueError:
|
||||
data.app_access_mode = None
|
||||
@ -1,170 +0,0 @@
|
||||
"""Pipeline steps. Each is one responsibility.
|
||||
|
||||
`BearerCheck` is the only step that touches the token registry; downstream
|
||||
steps see only the populated `Context`. `BearerCheck` also publishes the
|
||||
resolved identity to the openapi auth ``ContextVar`` (the same one the
|
||||
decorator-level :func:`libs.oauth_bearer.validate_bearer` writes to) so the
|
||||
surface gate and any handler reading the request-scoped context has a single
|
||||
source of truth across both auth-attach paths. The reset token is stashed
|
||||
on `ctx.auth_ctx_reset_token`; `Pipeline.guard` resets the ContextVar in
|
||||
its `finally` so worker-thread reuse can't leak identity across requests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter
|
||||
from controllers.openapi.auth.surface_gate import check_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
AuthContext,
|
||||
InvalidBearerError,
|
||||
Scope,
|
||||
SubjectType,
|
||||
check_workspace_membership,
|
||||
get_authenticator,
|
||||
set_auth_ctx,
|
||||
)
|
||||
from models import TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppService
|
||||
|
||||
|
||||
class BearerCheck:
|
||||
"""Resolve bearer → populate identity fields. Rate-limit is enforced
|
||||
inside `BearerAuthenticator.authenticate`, so no separate step here.
|
||||
Also publishes the resolved `AuthContext` via
|
||||
:func:`libs.oauth_bearer.set_auth_ctx` — same shape the decorator-level
|
||||
``validate_bearer`` writes — so the surface gate + downstream readers
|
||||
don't see two different identity sources. The reset token is parked on
|
||||
``ctx.auth_ctx_reset_token`` for `Pipeline.guard` to consume."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if not ctx.bearer_token:
|
||||
raise Unauthorized("bearer required")
|
||||
|
||||
try:
|
||||
authn = get_authenticator().authenticate(ctx.bearer_token)
|
||||
except InvalidBearerError as e:
|
||||
raise Unauthorized(str(e))
|
||||
|
||||
ctx.subject_type = authn.subject_type
|
||||
ctx.subject_email = authn.subject_email
|
||||
ctx.subject_issuer = authn.subject_issuer
|
||||
ctx.account_id = authn.account_id
|
||||
ctx.scopes = frozenset(authn.scopes)
|
||||
ctx.source = authn.source
|
||||
ctx.token_id = authn.token_id
|
||||
ctx.expires_at = authn.expires_at
|
||||
ctx.token_hash = authn.token_hash
|
||||
ctx.cached_verified_tenants = dict(authn.verified_tenants)
|
||||
ctx.auth_ctx_reset_token = set_auth_ctx(authn)
|
||||
|
||||
|
||||
class ScopeCheck:
|
||||
"""Verify ctx.scopes (already populated by BearerCheck) covers required."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if Scope.FULL in ctx.scopes or ctx.required_scope in ctx.scopes:
|
||||
return
|
||||
raise Forbidden("insufficient_scope")
|
||||
|
||||
|
||||
class SurfaceCheck:
|
||||
"""Reject the request if the resolved subject is not in `accepted`."""
|
||||
|
||||
def __init__(self, *, accepted: frozenset[SubjectType]) -> None:
|
||||
self._accepted = accepted
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
check_surface(self._accepted)
|
||||
|
||||
|
||||
class AppResolver:
|
||||
"""Read ``app_id`` from ``ctx.path_params``; populate ctx.app + ctx.tenant.
|
||||
|
||||
Every endpoint using the OAuth bearer pipeline must declare
|
||||
``<string:app_id>`` in its route — that is the design lock-in (no body /
|
||||
header coupling). ``Pipeline.guard`` lifts ``request.view_args`` into
|
||||
``ctx.path_params`` at the boundary so this step doesn't need to know
|
||||
about the request object.
|
||||
"""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
app_id = ctx.path_params.get("app_id")
|
||||
if not app_id:
|
||||
raise BadRequest("app_id is required in path")
|
||||
app = AppService.get_app_by_id(db.session, app_id)
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
if not app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
tenant = TenantService.get_tenant_by_id(db.session, str(app.tenant_id))
|
||||
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("workspace unavailable")
|
||||
ctx.app, ctx.tenant = app, tenant
|
||||
|
||||
|
||||
class WorkspaceMembershipCheck:
|
||||
"""Layer 0 — workspace membership gate.
|
||||
|
||||
CE-only (skipped when ENTERPRISE_ENABLED). Account-subject bearers
|
||||
(dfoa_) only — SSO subjects skip.
|
||||
"""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
return
|
||||
if ctx.subject_type != SubjectType.ACCOUNT:
|
||||
return
|
||||
if ctx.account_id is None or ctx.tenant is None:
|
||||
raise Unauthorized("account_id or tenant unset — BearerCheck or AppResolver did not run")
|
||||
if ctx.token_hash is None:
|
||||
raise Unauthorized("token_hash unset — BearerCheck did not run")
|
||||
|
||||
check_workspace_membership(
|
||||
account_id=ctx.account_id,
|
||||
tenant_id=ctx.must_tenant.id,
|
||||
token_hash=ctx.token_hash,
|
||||
cached_verdicts=ctx.cached_verified_tenants or {},
|
||||
)
|
||||
|
||||
|
||||
class AppAuthzCheck:
|
||||
def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None:
|
||||
self._resolve = resolve_strategy
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if not self._resolve().authorize(ctx):
|
||||
raise Forbidden("subject_no_app_access")
|
||||
|
||||
|
||||
class CallerMount:
|
||||
def __init__(self, *mounters: CallerMounter) -> None:
|
||||
self._mounters = mounters
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if ctx.subject_type is None:
|
||||
raise Unauthorized("subject_type unset — BearerCheck did not run")
|
||||
for m in self._mounters:
|
||||
if m.applies_to(ctx.must_subject_type):
|
||||
m.mount(ctx)
|
||||
return
|
||||
raise Unauthorized("no caller mounter for subject type")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AppAuthzCheck",
|
||||
"AppResolver",
|
||||
"AuthContext",
|
||||
"BearerCheck",
|
||||
"CallerMount",
|
||||
"ScopeCheck",
|
||||
"SurfaceCheck",
|
||||
"WorkspaceMembershipCheck",
|
||||
]
|
||||
@ -1,168 +0,0 @@
|
||||
"""Strategy classes for the openapi auth pipeline.
|
||||
|
||||
App authorization (Acl/Membership) and caller mounting (Account/EndUser)
|
||||
vary along independent axes; each strategy is one class so the pipeline
|
||||
composition stays a flat list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import user_logged_in
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.end_user_service import EndUserService
|
||||
from services.enterprise.enterprise_service import (
|
||||
EnterpriseService,
|
||||
WebAppAccessMode,
|
||||
)
|
||||
|
||||
|
||||
class AppAuthzStrategy(Protocol):
|
||||
def authorize(self, ctx: Context) -> bool: ...
|
||||
|
||||
|
||||
class AclStrategy:
|
||||
"""Per-app ACL, evaluated in two stages.
|
||||
|
||||
The EE gateway has already enforced tenancy and workspace membership
|
||||
by the time this strategy runs, so AclStrategy only owns per-app ACL:
|
||||
|
||||
1. Subject vs access-mode compatibility (pure rule table). External-SSO
|
||||
bearers belong to public-facing apps only; account bearers cover the
|
||||
full set. A mismatch is an immediate deny — no IO.
|
||||
2. For modes that pair with the subject, decide whether the inner
|
||||
permission API must run. Only `PRIVATE` (per-app selected-user list)
|
||||
requires it; the remaining modes are pass-through.
|
||||
"""
|
||||
|
||||
_ALLOWED_MODES_BY_SUBJECT: dict[SubjectType, frozenset[WebAppAccessMode]] = {
|
||||
SubjectType.ACCOUNT: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
WebAppAccessMode.PRIVATE_ALL,
|
||||
WebAppAccessMode.PRIVATE,
|
||||
}
|
||||
),
|
||||
SubjectType.EXTERNAL_SSO: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
_MODES_REQUIRING_INNER_CHECK: frozenset[WebAppAccessMode] = frozenset({WebAppAccessMode.PRIVATE})
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.app is None:
|
||||
return False
|
||||
access_mode = self._fetch_access_mode(ctx.app.id)
|
||||
if access_mode is None:
|
||||
return False
|
||||
if not self._subject_allowed_for_mode(ctx.must_subject_type, access_mode):
|
||||
return False
|
||||
if access_mode not in self._MODES_REQUIRING_INNER_CHECK:
|
||||
return True
|
||||
return self._inner_permission_check(ctx)
|
||||
|
||||
@staticmethod
|
||||
def _fetch_access_mode(app_id: str) -> WebAppAccessMode | None:
|
||||
settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
|
||||
if settings is None:
|
||||
return None
|
||||
try:
|
||||
return WebAppAccessMode(settings.access_mode)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _subject_allowed_for_mode(cls, subject_type: SubjectType, access_mode: WebAppAccessMode) -> bool:
|
||||
return access_mode in cls._ALLOWED_MODES_BY_SUBJECT.get(subject_type, frozenset())
|
||||
|
||||
def _inner_permission_check(self, ctx: Context) -> bool:
|
||||
if ctx.app is None:
|
||||
return False
|
||||
user_id = self._resolve_user_id(ctx)
|
||||
if user_id is None:
|
||||
return False
|
||||
return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
user_id=user_id,
|
||||
app_id=ctx.app.id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_user_id(ctx: Context) -> str | None:
|
||||
if ctx.subject_type == SubjectType.ACCOUNT:
|
||||
return str(ctx.account_id) if ctx.account_id is not None else None
|
||||
if ctx.subject_email is None:
|
||||
return None
|
||||
account = AccountService.get_account_by_email(db.session, ctx.subject_email)
|
||||
return str(account.id) if account is not None else None
|
||||
|
||||
|
||||
class MembershipStrategy:
|
||||
"""Tenant-membership fallback.
|
||||
|
||||
Used when webapp-auth is disabled (CE deployment). Account-bearing
|
||||
subjects pass if they have a TenantAccountJoin row; EXTERNAL_SSO is
|
||||
denied (it requires the webapp-auth surface).
|
||||
"""
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
return False
|
||||
if ctx.tenant is None:
|
||||
return False
|
||||
return TenantService.account_belongs_to_tenant(db.session, ctx.account_id, ctx.tenant.id)
|
||||
|
||||
|
||||
def _login_as(user) -> None:
|
||||
"""Set Flask-Login request user so downstream services see the caller."""
|
||||
current_app.login_manager._update_request_context_with_user(user) # type:ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=user) # type:ignore
|
||||
|
||||
|
||||
class CallerMounter(Protocol):
|
||||
def applies_to(self, subject_type: SubjectType) -> bool: ...
|
||||
|
||||
def mount(self, ctx: Context) -> None: ...
|
||||
|
||||
|
||||
class AccountMounter:
|
||||
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||
return subject_type == SubjectType.ACCOUNT
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.account_id is None:
|
||||
raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run")
|
||||
account = AccountService.get_account_by_id(db.session, str(ctx.account_id))
|
||||
if account is None:
|
||||
raise RuntimeError("AccountMounter: account row missing for resolved bearer")
|
||||
account.current_tenant = ctx.must_tenant
|
||||
_login_as(account)
|
||||
ctx.caller, ctx.caller_kind = account, "account"
|
||||
|
||||
|
||||
class EndUserMounter:
|
||||
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||
return subject_type == SubjectType.EXTERNAL_SSO
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.tenant is None or ctx.app is None or ctx.subject_email is None:
|
||||
raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run")
|
||||
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||
InvokeFrom.OPENAPI,
|
||||
tenant_id=ctx.tenant.id,
|
||||
app_id=ctx.app.id,
|
||||
user_id=ctx.subject_email,
|
||||
)
|
||||
_login_as(end_user)
|
||||
ctx.caller, ctx.caller_kind = end_user, "end_user"
|
||||
82
api/controllers/openapi/auth/verify.py
Normal file
82
api/controllers/openapi/auth/verify.py
Normal file
@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from werkzeug.exceptions import Forbidden, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType, check_workspace_membership
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode
|
||||
|
||||
|
||||
def check_scope(data: AuthData) -> None:
|
||||
if data.required_scope is None:
|
||||
return
|
||||
if Scope.FULL in data.scopes or data.required_scope in data.scopes:
|
||||
return
|
||||
raise Forbidden("insufficient_scope")
|
||||
|
||||
|
||||
def check_membership(data: AuthData) -> None:
|
||||
if data.tenant is None:
|
||||
raise Unauthorized("tenant unset")
|
||||
if data.account_id is None:
|
||||
raise Unauthorized("account_id unset")
|
||||
check_workspace_membership(
|
||||
account_id=data.account_id,
|
||||
tenant_id=data.tenant.id,
|
||||
token_hash=data.token_hash,
|
||||
membership_cache=data.tenants,
|
||||
)
|
||||
|
||||
|
||||
def check_app_access(data: AuthData) -> None:
|
||||
if data.tenant is None:
|
||||
return
|
||||
if not TenantService.account_belongs_to_tenant(db.session, data.account_id, data.tenant.id):
|
||||
raise Forbidden("subject_no_app_access")
|
||||
|
||||
|
||||
_ALLOWED_MODES_BY_TOKEN_TYPE: dict[TokenType, frozenset[WebAppAccessMode]] = {
|
||||
TokenType.OAUTH_ACCOUNT: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
WebAppAccessMode.PRIVATE_ALL,
|
||||
WebAppAccessMode.PRIVATE,
|
||||
}
|
||||
),
|
||||
TokenType.OAUTH_EXTERNAL_SSO: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def check_acl(data: AuthData) -> None:
|
||||
if data.app is None or data.app_access_mode is None:
|
||||
raise Forbidden("app or access mode not loaded")
|
||||
allowed_modes = _ALLOWED_MODES_BY_TOKEN_TYPE.get(data.token_type, frozenset())
|
||||
if data.app_access_mode not in allowed_modes:
|
||||
raise Forbidden("subject_not_allowed_for_access_mode")
|
||||
|
||||
|
||||
def check_private_app_permission(data: AuthData) -> None:
|
||||
if data.app is None:
|
||||
raise Forbidden("app not loaded")
|
||||
user_id = _resolve_user_id(data)
|
||||
if user_id is None:
|
||||
raise Forbidden("cannot resolve user for private app check")
|
||||
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id=user_id, app_id=data.app.id):
|
||||
raise Forbidden("user_not_allowed_for_private_app")
|
||||
|
||||
|
||||
def _resolve_user_id(data: AuthData) -> str | None:
|
||||
if data.token_type == TokenType.OAUTH_ACCOUNT:
|
||||
return str(data.account_id) if data.account_id is not None else None
|
||||
if data.external_identity is None:
|
||||
return None
|
||||
account = AccountService.get_account_by_email(db.session, data.external_identity.email)
|
||||
return str(account.id) if account is not None else None
|
||||
@ -17,11 +17,11 @@ from controllers.common.errors import (
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileResponse
|
||||
from libs.oauth_bearer import Scope
|
||||
from models import Account, App
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@ -39,8 +39,9 @@ class AppFileUploadApi(Resource):
|
||||
}
|
||||
)
|
||||
@openapi_ns.response(HTTPStatus.CREATED, "File uploaded", openapi_ns.models[FileResponse.__name__])
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, app_model: App, caller: Account, caller_kind: str):
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, _ = auth_data.require_app_context()
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
if len(request.files) > 1:
|
||||
|
||||
@ -17,7 +17,8 @@ from werkzeug.exceptions import BadRequest, NotFound
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import to_timestamp
|
||||
@ -55,8 +56,9 @@ def _ensure_form_is_allowed_for_openapi(form) -> None:
|
||||
@openapi_ns.route("/apps/<string:app_id>/form/human_input/<string:form_token>")
|
||||
class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||
@openapi_ns.response(200, "Form definition")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def get(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def get(self, app_id: str, form_token: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
@ -69,8 +71,9 @@ class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
@openapi_ns.response(200, "Form submitted")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, form_token: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
payload = HumanInputFormSubmitPayload.model_validate(request.get_json(silent=True) or {})
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
|
||||
@ -17,7 +17,8 @@ from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
@ -28,7 +29,7 @@ from core.workflow.human_input_policy import HumanInputSurface
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, AppMode
|
||||
from models.model import AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
|
||||
@ -36,8 +37,9 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/events")
|
||||
class OpenApiWorkflowEventsApi(Resource):
|
||||
@openapi_ns.response(200, "SSE event stream")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def get(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def get(self, app_id: str, task_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
|
||||
raise UnprocessableEntity("mode_not_supported_for_event_reconnect")
|
||||
|
||||
@ -35,15 +35,11 @@ from controllers.openapi._models import (
|
||||
WorkspaceListResponse,
|
||||
WorkspaceSummaryResponse,
|
||||
)
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.auth.role_gate import require_workspace_role
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
SubjectType,
|
||||
get_auth_ctx,
|
||||
validate_bearer,
|
||||
)
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models import Account, Tenant, TenantAccountJoin
|
||||
from models.account import TenantAccountRole, TenantStatus
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
@ -60,11 +56,6 @@ from services.feature_service import FeatureService
|
||||
|
||||
|
||||
def _validate_body[M: BaseModel](model: type[M]) -> M:
|
||||
"""Validate JSON body against ``model``. Validation errors → HTTP 400.
|
||||
|
||||
The workspace spec is explicit that bad email / unknown role payloads
|
||||
are 400, not Pydantic's default 422 — handle uniformly here.
|
||||
"""
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
return model.model_validate(body)
|
||||
@ -91,7 +82,6 @@ def _load_tenant(workspace_id: str) -> Tenant:
|
||||
|
||||
|
||||
def _load_account(account_id: object) -> Account:
|
||||
"""Load the caller's Account. Missing == auth wiring bug, not user error."""
|
||||
account = AccountService.get_account_by_id(db.session, str(account_id)) if account_id else None
|
||||
if account is None:
|
||||
raise RuntimeError("authenticated account_id has no Account row")
|
||||
@ -99,13 +89,6 @@ def _load_account(account_id: object) -> Account:
|
||||
|
||||
|
||||
def _quota_error(*, code: str, message: str, hint: str) -> Forbidden:
|
||||
"""Build a 403 with envelope ``{code, message, hint}``.
|
||||
|
||||
CLI ``error-mapper`` reads ``message`` and ``hint`` off the wire body
|
||||
verbatim — the structured envelope lets it surface remediation guidance
|
||||
(e.g. "upgrade your plan") without the CLI needing to know edition
|
||||
semantics.
|
||||
"""
|
||||
err = Forbidden(message)
|
||||
err.response = make_response(
|
||||
jsonify({"code": code, "message": message, "hint": hint}),
|
||||
@ -115,16 +98,6 @@ def _quota_error(*, code: str, message: str, hint: str) -> Forbidden:
|
||||
|
||||
|
||||
def _check_member_invite_quota(tenant_id: str) -> None:
|
||||
"""Edition-aware member-count gate for invite.
|
||||
|
||||
Both branches self-disable on CE because ``FeatureService.get_features``
|
||||
leaves ``billing.enabled`` and ``workspace_members.enabled`` False by
|
||||
default; SaaS billing API and EE license activation are what flip them on.
|
||||
|
||||
Mirrors the two checks the console invite path performs (decorator at
|
||||
``console/wraps.py:106`` for billing + inline at
|
||||
``console/workspace/members.py:130`` for license).
|
||||
"""
|
||||
features = FeatureService.get_features(tenant_id)
|
||||
|
||||
if features.billing.enabled:
|
||||
@ -148,12 +121,9 @@ def _check_member_invite_quota(tenant_id: str) -> None:
|
||||
@openapi_ns.route("/workspaces")
|
||||
class WorkspacesApi(Resource):
|
||||
@openapi_ns.response(200, "Workspace list", openapi_ns.models[WorkspaceListResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
def get(self):
|
||||
ctx = get_auth_ctx()
|
||||
|
||||
rows = TenantService.get_workspaces_for_account(db.session, str(ctx.account_id))
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
rows = TenantService.get_workspaces_for_account(db.session, str(auth_data.account_id))
|
||||
|
||||
return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows))).model_dump(mode="json"), 200
|
||||
|
||||
@ -161,12 +131,9 @@ class WorkspacesApi(Resource):
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>")
|
||||
class WorkspaceByIdApi(Resource):
|
||||
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
def get(self, workspace_id: str):
|
||||
ctx = get_auth_ctx()
|
||||
|
||||
row = TenantService.find_workspace_for_account(db.session, str(ctx.account_id), workspace_id)
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, workspace_id: str, *, auth_data: AuthData):
|
||||
row = TenantService.find_workspace_for_account(db.session, str(auth_data.account_id), workspace_id)
|
||||
# 404 (not 403) on non-member so workspace IDs don't leak across tenants.
|
||||
if row is None:
|
||||
raise NotFound("workspace not found")
|
||||
@ -185,21 +152,17 @@ class WorkspaceSwitchApi(Resource):
|
||||
"""
|
||||
|
||||
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role()
|
||||
def post(self, workspace_id: str):
|
||||
ctx = get_auth_ctx()
|
||||
account = _load_account(ctx.account_id)
|
||||
def post(self, workspace_id: str, *, auth_data: AuthData):
|
||||
account = _load_account(auth_data.account_id)
|
||||
|
||||
try:
|
||||
TenantService.switch_tenant(account, workspace_id)
|
||||
except AccountNotLinkTenantError:
|
||||
# Membership existed at gate time but Tenant.status != NORMAL or
|
||||
# the row was just removed — treat as not-found.
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
row = TenantService.find_workspace_for_account(db.session, str(ctx.account_id), workspace_id)
|
||||
row = TenantService.find_workspace_for_account(db.session, str(auth_data.account_id), workspace_id)
|
||||
if row is None:
|
||||
raise NotFound("workspace not found")
|
||||
tenant, membership = row
|
||||
@ -216,20 +179,15 @@ class WorkspaceMembersApi(Resource):
|
||||
|
||||
@openapi_ns.doc(params=query_params_from_model(MemberListQuery))
|
||||
@openapi_ns.response(200, "Member list", openapi_ns.models[MemberListResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role()
|
||||
def get(self, workspace_id: str):
|
||||
def get(self, workspace_id: str, *, auth_data: AuthData):
|
||||
try:
|
||||
query = MemberListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
tenant = _load_tenant(workspace_id)
|
||||
# Members per workspace are bounded by SaaS plan caps (≤50) or EE
|
||||
# license seats (low thousands worst-case), so we materialize and
|
||||
# slice in-memory rather than push pagination into the service —
|
||||
# matches how the rest of the service exposes member lists.
|
||||
members = TenantService.get_tenant_members(tenant)
|
||||
total = len(members)
|
||||
start = (query.page - 1) * query.limit
|
||||
@ -244,13 +202,11 @@ class WorkspaceMembersApi(Resource):
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[MemberInvitePayload.__name__])
|
||||
@openapi_ns.response(201, "Member invited", openapi_ns.models[MemberInviteResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
def post(self, workspace_id: str):
|
||||
def post(self, workspace_id: str, *, auth_data: AuthData):
|
||||
payload = _validate_body(MemberInvitePayload)
|
||||
ctx = get_auth_ctx()
|
||||
inviter = _load_account(ctx.account_id)
|
||||
inviter = _load_account(auth_data.account_id)
|
||||
tenant = _load_tenant(workspace_id)
|
||||
|
||||
_check_member_invite_quota(str(tenant.id))
|
||||
@ -297,12 +253,10 @@ class WorkspaceMemberApi(Resource):
|
||||
"""
|
||||
|
||||
@openapi_ns.response(200, "Member removed", openapi_ns.models[MemberActionResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
def delete(self, workspace_id: str, member_id: str):
|
||||
ctx = get_auth_ctx()
|
||||
operator = _load_account(ctx.account_id)
|
||||
def delete(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
|
||||
operator = _load_account(auth_data.account_id)
|
||||
tenant = _load_tenant(workspace_id)
|
||||
member = AccountService.get_account_by_id(db.session, member_id)
|
||||
if member is None:
|
||||
@ -330,13 +284,11 @@ class WorkspaceMemberRoleApi(Resource):
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[MemberRoleUpdatePayload.__name__])
|
||||
@openapi_ns.response(200, "Role updated", openapi_ns.models[MemberActionResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
def put(self, workspace_id: str, member_id: str):
|
||||
def put(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
|
||||
payload = _validate_body(MemberRoleUpdatePayload)
|
||||
ctx = get_auth_ctx()
|
||||
operator = _load_account(ctx.account_id)
|
||||
operator = _load_account(auth_data.account_id)
|
||||
tenant = _load_tenant(workspace_id)
|
||||
member = AccountService.get_account_by_id(db.session, member_id)
|
||||
if member is None:
|
||||
|
||||
@ -27,6 +27,7 @@ from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
|
||||
from core.workflow.node_factory import get_default_root_node_id
|
||||
from core.workflow.nodes.agent_v2.session_cleanup_layer import build_workflow_agent_session_cleanup_layer
|
||||
from core.workflow.system_variables import (
|
||||
build_bootstrap_variables,
|
||||
build_system_variables,
|
||||
@ -239,6 +240,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
workflow_entry.graph_engine.layer(build_workflow_agent_session_cleanup_layer())
|
||||
conversation_variable_layer = ConversationVariablePersistenceLayer(
|
||||
ConversationVariableUpdater(session_factory.get_session_maker())
|
||||
)
|
||||
|
||||
@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
|
||||
from core.workflow.node_factory import get_default_root_node_id
|
||||
from core.workflow.nodes.agent_v2.session_cleanup_layer import build_workflow_agent_session_cleanup_layer
|
||||
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
|
||||
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
@ -166,6 +167,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
workflow_entry.graph_engine.layer(build_workflow_agent_session_cleanup_layer())
|
||||
for layer in self._graph_engine_layers:
|
||||
workflow_entry.graph_engine.layer(layer)
|
||||
|
||||
|
||||
@ -475,6 +475,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
from core.workflow.nodes.agent_v2.file_tenant_validator import UploadFileTenantValidator
|
||||
from core.workflow.nodes.agent_v2.output_failure_orchestrator import OutputFailureOrchestrator
|
||||
from core.workflow.nodes.agent_v2.output_type_checker import PerOutputTypeChecker
|
||||
from core.workflow.nodes.agent_v2.session_store import WorkflowAgentRuntimeSessionStore
|
||||
|
||||
return {
|
||||
"binding_resolver": WorkflowAgentBindingResolver(),
|
||||
@ -494,6 +495,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
# outputs contain no file refs.
|
||||
"type_checker": PerOutputTypeChecker(file_validator=UploadFileTenantValidator()),
|
||||
"failure_orchestrator": OutputFailureOrchestrator(),
|
||||
"session_store": WorkflowAgentRuntimeSessionStore(),
|
||||
}
|
||||
return {
|
||||
"strategy_resolver": self._agent_strategy_resolver,
|
||||
|
||||
@ -1,8 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
|
||||
from clients.agent_backend import (
|
||||
AgentBackendError,
|
||||
AgentBackendHTTPError,
|
||||
@ -17,11 +20,14 @@ from clients.agent_backend import (
|
||||
AgentBackendStreamInternalEvent,
|
||||
AgentBackendTransportError,
|
||||
AgentBackendValidationError,
|
||||
CleanupLayerSpec,
|
||||
extract_cleanup_layer_specs,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
from graphon.entities.pause_reason import SchedulingPause
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
|
||||
from graphon.node_events import NodeEventBase, NodeRunResult, PauseRequestedEvent, StreamCompletedEvent
|
||||
from graphon.nodes.base.node import Node
|
||||
from models.agent_config_entities import WorkflowNodeJobConfig
|
||||
|
||||
@ -40,11 +46,14 @@ from .runtime_request_builder import (
|
||||
WorkflowAgentRuntimeRequestBuilder,
|
||||
WorkflowAgentRuntimeRequestBuildError,
|
||||
)
|
||||
from .session_store import WorkflowAgentRuntimeSessionStore, WorkflowAgentSessionScope
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphon.entities import GraphInitParams
|
||||
from graphon.runtime import GraphRuntimeState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Stage 4 §5+§7: the terminal events that `_consume_event_stream` may return.
|
||||
# Stream + started events are filtered out before we yield; transport errors
|
||||
@ -74,6 +83,7 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
output_adapter: WorkflowAgentOutputAdapter,
|
||||
type_checker: PerOutputTypeChecker,
|
||||
failure_orchestrator: OutputFailureOrchestrator,
|
||||
session_store: WorkflowAgentRuntimeSessionStore | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
node_id=node_id,
|
||||
@ -88,6 +98,7 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
self._output_adapter = output_adapter
|
||||
self._type_checker = type_checker
|
||||
self._failure_orchestrator = failure_orchestrator
|
||||
self._session_store = session_store
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
@ -134,6 +145,17 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
"agent_config_snapshot_id": bundle.snapshot.id,
|
||||
"binding_id": bundle.binding.id,
|
||||
}
|
||||
session_scope = WorkflowAgentSessionScope(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
node_id=self._node_id,
|
||||
node_execution_id=self.id,
|
||||
binding_id=bundle.binding.id,
|
||||
agent_id=bundle.agent.id,
|
||||
agent_config_snapshot_id=bundle.snapshot.id,
|
||||
)
|
||||
|
||||
# Stage 4 §4.1 (D-3): use effective outputs so defaults flow through both
|
||||
# the backend request and the post-run type check.
|
||||
@ -147,6 +169,9 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
attempt = 0
|
||||
while True:
|
||||
try:
|
||||
session_snapshot = None
|
||||
if self._session_store is not None:
|
||||
session_snapshot = self._session_store.load_active_snapshot(session_scope)
|
||||
runtime_request = self._runtime_request_builder.build(
|
||||
WorkflowAgentRuntimeBuildContext(
|
||||
dify_context=dify_ctx,
|
||||
@ -159,6 +184,7 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
agent=bundle.agent,
|
||||
snapshot=bundle.snapshot,
|
||||
attempt=attempt,
|
||||
session_snapshot=session_snapshot,
|
||||
)
|
||||
)
|
||||
except WorkflowAgentRuntimeRequestBuildError as error:
|
||||
@ -221,9 +247,35 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
)
|
||||
return
|
||||
|
||||
# Non-success terminal (failed / cancelled / paused) skips per-output
|
||||
# post-processing — the backend itself already failed.
|
||||
if isinstance(terminal_event, AgentBackendRunPausedInternalEvent):
|
||||
self._save_session_snapshot(
|
||||
session_scope=session_scope,
|
||||
backend_run_id=terminal_event.run_id,
|
||||
snapshot=terminal_event.session_snapshot,
|
||||
composition_layer_specs=extract_cleanup_layer_specs(runtime_request.request.composition),
|
||||
metadata=metadata,
|
||||
)
|
||||
yield PauseRequestedEvent(
|
||||
reason=SchedulingPause(
|
||||
message=terminal_event.message
|
||||
or "Agent backend run requested workflow pause for external input."
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Non-success terminal (failed / cancelled) skips per-output
|
||||
# post-processing — the backend itself already failed. We also retire
|
||||
# the local ACTIVE session row so a workflow loop back into the same
|
||||
# Agent node cannot resume from a stale snapshot. The failed agent
|
||||
# backend layers (suspended per ``on_exit``) are left for agent
|
||||
# backend's own GC; this row will no longer be picked up by the
|
||||
# workflow-terminal cleanup layer.
|
||||
if not isinstance(terminal_event, AgentBackendRunSucceededInternalEvent):
|
||||
self._mark_session_cleaned_on_failure(
|
||||
session_scope=session_scope,
|
||||
backend_run_id=terminal_event.run_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=self._output_adapter.build_failure_result(
|
||||
event=terminal_event,
|
||||
@ -234,6 +286,14 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
)
|
||||
return
|
||||
|
||||
self._save_session_snapshot(
|
||||
session_scope=session_scope,
|
||||
backend_run_id=terminal_event.run_id,
|
||||
snapshot=terminal_event.session_snapshot,
|
||||
composition_layer_specs=extract_cleanup_layer_specs(runtime_request.request.composition),
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
# ──── Stage 4: per-output type check ────
|
||||
type_check = self._type_checker.check(
|
||||
declared_outputs=effective_outputs,
|
||||
@ -384,6 +444,75 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
],
|
||||
}
|
||||
|
||||
def _save_session_snapshot(
|
||||
self,
|
||||
*,
|
||||
session_scope: WorkflowAgentSessionScope,
|
||||
backend_run_id: str,
|
||||
snapshot: CompositorSessionSnapshot | None,
|
||||
composition_layer_specs: list[CleanupLayerSpec],
|
||||
metadata: dict[str, Any],
|
||||
) -> None:
|
||||
if self._session_store is None:
|
||||
return
|
||||
try:
|
||||
self._session_store.save_active_snapshot(
|
||||
scope=session_scope,
|
||||
backend_run_id=backend_run_id,
|
||||
snapshot=snapshot,
|
||||
composition_layer_specs=composition_layer_specs,
|
||||
)
|
||||
agent_backend = dict(metadata.get("agent_backend") or {})
|
||||
agent_backend["session_snapshot_persisted"] = snapshot is not None
|
||||
metadata["agent_backend"] = agent_backend
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist workflow Agent runtime session snapshot: "
|
||||
"tenant_id=%s workflow_run_id=%s node_id=%s binding_id=%s agent_id=%s backend_run_id=%s",
|
||||
session_scope.tenant_id,
|
||||
session_scope.workflow_run_id,
|
||||
session_scope.node_id,
|
||||
session_scope.binding_id,
|
||||
session_scope.agent_id,
|
||||
backend_run_id,
|
||||
exc_info=True,
|
||||
)
|
||||
agent_backend = dict(metadata.get("agent_backend") or {})
|
||||
agent_backend["session_snapshot_persisted"] = False
|
||||
agent_backend["session_snapshot_persist_error"] = "workflow_agent_runtime_session_store_error"
|
||||
metadata["agent_backend"] = agent_backend
|
||||
|
||||
def _mark_session_cleaned_on_failure(
|
||||
self,
|
||||
*,
|
||||
session_scope: WorkflowAgentSessionScope,
|
||||
backend_run_id: str,
|
||||
metadata: dict[str, Any],
|
||||
) -> None:
|
||||
if self._session_store is None:
|
||||
return
|
||||
try:
|
||||
self._session_store.mark_cleaned(scope=session_scope, backend_run_id=backend_run_id)
|
||||
agent_backend = dict(metadata.get("agent_backend") or {})
|
||||
agent_backend["session_snapshot_cleaned_on_failure"] = True
|
||||
metadata["agent_backend"] = agent_backend
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to mark workflow Agent runtime session cleaned on agent run failure: "
|
||||
"tenant_id=%s workflow_run_id=%s node_id=%s binding_id=%s agent_id=%s backend_run_id=%s",
|
||||
session_scope.tenant_id,
|
||||
session_scope.workflow_run_id,
|
||||
session_scope.node_id,
|
||||
session_scope.binding_id,
|
||||
session_scope.agent_id,
|
||||
backend_run_id,
|
||||
exc_info=True,
|
||||
)
|
||||
agent_backend = dict(metadata.get("agent_backend") or {})
|
||||
agent_backend["session_snapshot_cleaned_on_failure"] = False
|
||||
agent_backend["session_snapshot_cleanup_error"] = "workflow_agent_runtime_session_store_error"
|
||||
metadata["agent_backend"] = agent_backend
|
||||
|
||||
@staticmethod
|
||||
def _patch_event_with_defaults(
|
||||
event: AgentBackendRunSucceededInternalEvent,
|
||||
|
||||
@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, Protocol, cast
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig
|
||||
from dify_agent.protocol import CreateRunRequest
|
||||
|
||||
@ -28,6 +29,7 @@ from models.agent_config_entities import (
|
||||
from models.agent_config_entities import (
|
||||
effective_declared_outputs as _effective_declared_outputs,
|
||||
)
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
from .output_failure_orchestrator import retry_idempotency_key
|
||||
from .plugin_tools_builder import WorkflowAgentPluginToolsBuilder, WorkflowAgentPluginToolsBuildError
|
||||
@ -66,6 +68,7 @@ class WorkflowAgentRuntimeBuildContext:
|
||||
# Stage 4 §7 / D-4: 0 for the first run, then incremented per retry. Drives the
|
||||
# idempotency key so the backend treats each retry as a fresh request.
|
||||
attempt: int = 0
|
||||
session_snapshot: CompositorSessionSnapshot | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
@ -129,11 +132,14 @@ class WorkflowAgentRuntimeRequestBuilder:
|
||||
request = self._request_builder.build_for_workflow_node(
|
||||
AgentBackendWorkflowNodeRunInput(
|
||||
model=AgentBackendModelConfig(
|
||||
plugin_id=agent_soul.model.plugin_id,
|
||||
model_provider=agent_soul.model.model_provider,
|
||||
plugin_id=self._plugin_daemon_plugin_id(
|
||||
plugin_id=agent_soul.model.plugin_id,
|
||||
model_provider=agent_soul.model.model_provider,
|
||||
),
|
||||
model_provider=self._plugin_daemon_provider_name(agent_soul.model.model_provider),
|
||||
model=agent_soul.model.model,
|
||||
credentials=self._normalize_credentials(credentials),
|
||||
model_settings=cast(dict[str, Any], agent_soul.model.model_settings),
|
||||
model_settings=agent_soul.model.model_settings,
|
||||
),
|
||||
# The execution-context layer is now the only public protocol
|
||||
# carrier for Dify tenant/user/run identifiers. ``user_id`` must
|
||||
@ -158,6 +164,7 @@ class WorkflowAgentRuntimeRequestBuilder:
|
||||
user_prompt=user_prompt,
|
||||
output=self._build_output_config(node_job.declared_outputs),
|
||||
tools=tools_layer,
|
||||
session_snapshot=context.session_snapshot,
|
||||
idempotency_key=self._idempotency_key(context),
|
||||
metadata=metadata,
|
||||
)
|
||||
@ -177,6 +184,20 @@ class WorkflowAgentRuntimeRequestBuilder:
|
||||
return "single_step"
|
||||
return "workflow_run"
|
||||
|
||||
@staticmethod
|
||||
def _plugin_daemon_plugin_id(*, plugin_id: str, model_provider: str) -> str:
|
||||
"""Return the transport plugin id expected by plugin-daemon headers."""
|
||||
if plugin_id.count("/") == 1:
|
||||
return plugin_id
|
||||
if plugin_id:
|
||||
return ModelProviderID(plugin_id).plugin_id
|
||||
return ModelProviderID(model_provider).plugin_id
|
||||
|
||||
@staticmethod
|
||||
def _plugin_daemon_provider_name(model_provider: str) -> str:
|
||||
"""Return the provider name expected by plugin-daemon dispatch payloads."""
|
||||
return ModelProviderID(model_provider).provider_name
|
||||
|
||||
@staticmethod
|
||||
def _idempotency_key(context: WorkflowAgentRuntimeBuildContext) -> str:
|
||||
# Stage 4 §7 / D-4: retries get distinct keys (``...:retry-{attempt}``) so
|
||||
|
||||
247
api/core/workflow/nodes/agent_v2/session_cleanup_layer.py
Normal file
247
api/core/workflow/nodes/agent_v2/session_cleanup_layer.py
Normal file
@ -0,0 +1,247 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import override
|
||||
|
||||
from clients.agent_backend import AgentBackendError, AgentBackendRunClient, AgentBackendRunRequestBuilder
|
||||
from clients.agent_backend.factory import create_agent_backend_run_client
|
||||
from configs import dify_config
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
|
||||
from .session_store import StoredWorkflowAgentSession, WorkflowAgentRuntimeSessionStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Upper bound on how long a cleanup-only run is allowed to settle before the
|
||||
# layer gives up and leaves the row ACTIVE so it can be retried later. Cleanup
|
||||
# work is mostly local agent-backend bookkeeping (no LLM inference), so 30s is
|
||||
# generous; a hung backend should never block workflow termination beyond this.
|
||||
_CLEANUP_WAIT_TIMEOUT_SECONDS = 30.0
|
||||
|
||||
|
||||
class WorkflowAgentSessionCleanupLayer(GraphEngineLayer):
|
||||
"""Retires workflow Agent session snapshots when a workflow reaches a terminal state.
|
||||
|
||||
Implementation notes — there are two failure modes the cleanup path has to
|
||||
avoid simultaneously:
|
||||
|
||||
1. The agenton compositor on the agent-backend side validates the cleanup
|
||||
request's session snapshot against the replayed composition before
|
||||
running any lifecycle hook. If the snapshot's layer names diverge from
|
||||
the composition, the run fails asynchronously with ``run_failed`` — but
|
||||
the initial ``POST /runs`` already returned 202, so the API side has no
|
||||
visibility of the failure unless it waits for terminal status. The
|
||||
``composition_layer_specs`` persistence in A.1–A.4 plus the
|
||||
``_filter_snapshot_to_specs`` shape in ``build_cleanup_request`` keeps
|
||||
the two name lists in sync.
|
||||
|
||||
2. The current agent backend's ``runner.py::_run_agent`` always invokes
|
||||
``run.get_layer("llm")`` and the structured-output / history validators
|
||||
before exiting any slot — there is no ``purpose: "cleanup"`` branch
|
||||
yet. A truly cleanup-only request (no LLM layer) therefore still
|
||||
crashes inside the runner with ``Layer 'llm' is not defined in this
|
||||
compositor run.``. Until the backend grows a cleanup-only purpose,
|
||||
this layer **does not issue an HTTP cleanup run**: it simply retires
|
||||
the local snapshot row so stale state cannot be re-resumed, and lets
|
||||
the agent backend's own retention TTL release the suspended layers.
|
||||
|
||||
The HTTP-cleanup machinery (``build_cleanup_request`` + ``wait_run``) is
|
||||
intentionally still wired into the request builder + integration tests so
|
||||
that when the agent backend supports cleanup runs we can flip the switch
|
||||
here with a one-line change (see ``_HTTP_CLEANUP_SUPPORTED``).
|
||||
"""
|
||||
|
||||
# Flip to True once dify-agent's runner has a ``purpose=cleanup`` branch
|
||||
# that skips the LLM/output/user-prompt invariants. Until then we only
|
||||
# update the local row; the spec list is still persisted so the future
|
||||
# HTTP cleanup path has everything it needs.
|
||||
_HTTP_CLEANUP_SUPPORTED: bool = False
|
||||
|
||||
_TERMINAL_EVENTS = (
|
||||
GraphRunSucceededEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunAbortedEvent,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
session_store: WorkflowAgentRuntimeSessionStore,
|
||||
request_builder: AgentBackendRunRequestBuilder,
|
||||
agent_backend_client: AgentBackendRunClient | None,
|
||||
cleanup_wait_timeout_seconds: float = _CLEANUP_WAIT_TIMEOUT_SECONDS,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._session_store = session_store
|
||||
self._request_builder = request_builder
|
||||
self._agent_backend_client = agent_backend_client
|
||||
self._cleanup_wait_timeout_seconds = cleanup_wait_timeout_seconds
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
return
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
if not isinstance(event, self._TERMINAL_EVENTS):
|
||||
return
|
||||
workflow_run_id = get_system_text(
|
||||
self.graph_runtime_state.variable_pool,
|
||||
SystemVariableKey.WORKFLOW_EXECUTION_ID,
|
||||
)
|
||||
if not workflow_run_id:
|
||||
logger.warning("Skipping workflow Agent session cleanup: workflow_run_id is missing.")
|
||||
return
|
||||
|
||||
for stored_session in self._session_store.list_active_sessions(workflow_run_id=workflow_run_id):
|
||||
self._cleanup_session(stored_session)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
return
|
||||
|
||||
def _cleanup_session(self, stored_session: StoredWorkflowAgentSession) -> None:
|
||||
scope = stored_session.scope
|
||||
if not self._HTTP_CLEANUP_SUPPORTED:
|
||||
# Agent backend has no cleanup-only run mode yet (see class
|
||||
# docstring). Retire the local row so future re-entries do not
|
||||
# resume from stale state, and let the backend's retention TTL
|
||||
# release the suspended layers on its own schedule.
|
||||
logger.info(
|
||||
"Workflow Agent session retired locally; HTTP cleanup is disabled "
|
||||
"until the agent backend supports a cleanup-only run mode. "
|
||||
"workflow_run_id=%s node_id=%s binding_id=%s agent_id=%s previous_run_id=%s",
|
||||
scope.workflow_run_id,
|
||||
scope.node_id,
|
||||
scope.binding_id,
|
||||
scope.agent_id,
|
||||
stored_session.backend_run_id,
|
||||
)
|
||||
self._session_store.mark_cleaned(scope=scope, backend_run_id=stored_session.backend_run_id)
|
||||
return
|
||||
|
||||
if self._agent_backend_client is None:
|
||||
# HTTP cleanup was enabled by the caller but no client was wired
|
||||
# in (e.g. the API runs without AGENT_BACKEND_BASE_URL configured).
|
||||
# Leave the row ACTIVE so an operator restart with proper config
|
||||
# can drive the cleanup; do not silently retire it.
|
||||
logger.warning(
|
||||
"Skipping Agent backend cleanup: HTTP cleanup is enabled but no agent "
|
||||
"backend client is wired in. workflow_run_id=%s node_id=%s agent_id=%s",
|
||||
scope.workflow_run_id,
|
||||
scope.node_id,
|
||||
scope.agent_id,
|
||||
)
|
||||
return
|
||||
|
||||
if not stored_session.composition_layer_specs:
|
||||
# Sessions persisted before A.1 landed do not carry the spec list,
|
||||
# so we cannot replay a valid cleanup composition. Leave the row
|
||||
# ACTIVE and warn so the absence shows up in observability rather
|
||||
# than being silently swallowed by a doomed cleanup run.
|
||||
logger.warning(
|
||||
"Skipping Agent backend cleanup: no composition_layer_specs persisted. "
|
||||
"workflow_run_id=%s node_id=%s agent_id=%s",
|
||||
scope.workflow_run_id,
|
||||
scope.node_id,
|
||||
scope.agent_id,
|
||||
)
|
||||
return
|
||||
|
||||
request = self._request_builder.build_cleanup_request(
|
||||
session_snapshot=stored_session.session_snapshot,
|
||||
composition_layer_specs=stored_session.composition_layer_specs,
|
||||
idempotency_key=f"{scope.workflow_run_id}:{scope.node_id}:{scope.binding_id}:agent-session-cleanup",
|
||||
metadata={
|
||||
"tenant_id": scope.tenant_id,
|
||||
"app_id": scope.app_id,
|
||||
"workflow_id": scope.workflow_id,
|
||||
"workflow_run_id": scope.workflow_run_id,
|
||||
"node_id": scope.node_id,
|
||||
"node_execution_id": scope.node_execution_id,
|
||||
"binding_id": scope.binding_id,
|
||||
"agent_id": scope.agent_id,
|
||||
"agent_config_snapshot_id": scope.agent_config_snapshot_id,
|
||||
"previous_agent_backend_run_id": stored_session.backend_run_id,
|
||||
},
|
||||
)
|
||||
try:
|
||||
response = self._agent_backend_client.create_run(request)
|
||||
except AgentBackendError:
|
||||
logger.warning(
|
||||
"Agent backend session cleanup request failed: workflow_run_id=%s node_id=%s agent_id=%s",
|
||||
scope.workflow_run_id,
|
||||
scope.node_id,
|
||||
scope.agent_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
status_response = self._agent_backend_client.wait_run(
|
||||
response.run_id, timeout_seconds=self._cleanup_wait_timeout_seconds
|
||||
)
|
||||
except AgentBackendError:
|
||||
logger.warning(
|
||||
"Agent backend session cleanup wait_run failed: "
|
||||
"workflow_run_id=%s node_id=%s agent_id=%s cleanup_run_id=%s",
|
||||
scope.workflow_run_id,
|
||||
scope.node_id,
|
||||
scope.agent_id,
|
||||
response.run_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
if status_response.status != "succeeded":
|
||||
logger.warning(
|
||||
"Agent backend session cleanup did not succeed: status=%s error=%s "
|
||||
"workflow_run_id=%s node_id=%s agent_id=%s cleanup_run_id=%s",
|
||||
status_response.status,
|
||||
status_response.error,
|
||||
scope.workflow_run_id,
|
||||
scope.node_id,
|
||||
scope.agent_id,
|
||||
response.run_id,
|
||||
)
|
||||
return
|
||||
|
||||
self._session_store.mark_cleaned(scope=scope, backend_run_id=response.run_id)
|
||||
|
||||
|
||||
def build_workflow_agent_session_cleanup_layer() -> WorkflowAgentSessionCleanupLayer:
|
||||
"""Wire the cleanup layer with the standard production dependencies.
|
||||
|
||||
The agent backend client is constructed only when ``AGENT_BACKEND_BASE_URL``
|
||||
is configured (or the deterministic fake is explicitly enabled). When
|
||||
neither is set — for example unit tests that bring up the workflow runner
|
||||
without an Agent node — we pass ``None`` so the layer stays harmless. With
|
||||
``_HTTP_CLEANUP_SUPPORTED = False`` the local-retire branch never touches
|
||||
the client anyway, but keeping it ``None`` avoids importing httpx and lets
|
||||
test harnesses skip backend configuration.
|
||||
"""
|
||||
agent_backend_client: AgentBackendRunClient | None
|
||||
if dify_config.AGENT_BACKEND_USE_FAKE or dify_config.AGENT_BACKEND_BASE_URL:
|
||||
agent_backend_client = create_agent_backend_run_client(
|
||||
base_url=dify_config.AGENT_BACKEND_BASE_URL,
|
||||
use_fake=dify_config.AGENT_BACKEND_USE_FAKE,
|
||||
fake_scenario=dify_config.AGENT_BACKEND_FAKE_SCENARIO,
|
||||
)
|
||||
else:
|
||||
agent_backend_client = None
|
||||
|
||||
return WorkflowAgentSessionCleanupLayer(
|
||||
session_store=WorkflowAgentRuntimeSessionStore(),
|
||||
request_builder=AgentBackendRunRequestBuilder(),
|
||||
agent_backend_client=agent_backend_client,
|
||||
)
|
||||
179
api/core/workflow/nodes/agent_v2/session_store.py
Normal file
179
api/core/workflow/nodes/agent_v2/session_store.py
Normal file
@ -0,0 +1,179 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
|
||||
from clients.agent_backend.request_builder import CleanupLayerSpec
|
||||
from core.db.session_factory import session_factory
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.agent import (
|
||||
WorkflowAgentRuntimeSession,
|
||||
WorkflowAgentRuntimeSessionStatus,
|
||||
)
|
||||
|
||||
_SPECS_ADAPTER: TypeAdapter[list[CleanupLayerSpec]] = TypeAdapter(list[CleanupLayerSpec])
|
||||
|
||||
|
||||
def _serialize_specs(specs: list[CleanupLayerSpec]) -> str:
|
||||
return _SPECS_ADAPTER.dump_json(specs).decode()
|
||||
|
||||
|
||||
def _deserialize_specs(value: str | None) -> list[CleanupLayerSpec]:
|
||||
if not value:
|
||||
return []
|
||||
return _SPECS_ADAPTER.validate_json(value)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class WorkflowAgentSessionScope:
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
workflow_id: str
|
||||
workflow_run_id: str | None
|
||||
node_id: str
|
||||
node_execution_id: str
|
||||
binding_id: str
|
||||
agent_id: str
|
||||
agent_config_snapshot_id: str
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class StoredWorkflowAgentSession:
|
||||
scope: WorkflowAgentSessionScope
|
||||
session_snapshot: CompositorSessionSnapshot
|
||||
backend_run_id: str | None
|
||||
composition_layer_specs: list[CleanupLayerSpec] = field(default_factory=list)
|
||||
|
||||
|
||||
class WorkflowAgentRuntimeSessionStore:
|
||||
"""Stores Agent backend session snapshots for workflow Agent node re-entry."""
|
||||
|
||||
def load_active_snapshot(self, scope: WorkflowAgentSessionScope) -> CompositorSessionSnapshot | None:
|
||||
if scope.workflow_run_id is None:
|
||||
return None
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
row = session.scalar(
|
||||
select(WorkflowAgentRuntimeSession).where(
|
||||
WorkflowAgentRuntimeSession.tenant_id == scope.tenant_id,
|
||||
WorkflowAgentRuntimeSession.workflow_run_id == scope.workflow_run_id,
|
||||
WorkflowAgentRuntimeSession.node_id == scope.node_id,
|
||||
WorkflowAgentRuntimeSession.binding_id == scope.binding_id,
|
||||
WorkflowAgentRuntimeSession.agent_id == scope.agent_id,
|
||||
WorkflowAgentRuntimeSession.status == WorkflowAgentRuntimeSessionStatus.ACTIVE,
|
||||
)
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
return CompositorSessionSnapshot.model_validate_json(row.session_snapshot)
|
||||
|
||||
def list_active_sessions(self, *, workflow_run_id: str) -> list[StoredWorkflowAgentSession]:
|
||||
with session_factory.create_session() as session:
|
||||
rows = session.scalars(
|
||||
select(WorkflowAgentRuntimeSession).where(
|
||||
WorkflowAgentRuntimeSession.workflow_run_id == workflow_run_id,
|
||||
WorkflowAgentRuntimeSession.status == WorkflowAgentRuntimeSessionStatus.ACTIVE,
|
||||
)
|
||||
).all()
|
||||
return [
|
||||
StoredWorkflowAgentSession(
|
||||
scope=WorkflowAgentSessionScope(
|
||||
tenant_id=row.tenant_id,
|
||||
app_id=row.app_id,
|
||||
workflow_id=row.workflow_id,
|
||||
workflow_run_id=row.workflow_run_id,
|
||||
node_id=row.node_id,
|
||||
node_execution_id=row.node_execution_id or "",
|
||||
binding_id=row.binding_id,
|
||||
agent_id=row.agent_id,
|
||||
agent_config_snapshot_id=row.agent_config_snapshot_id,
|
||||
),
|
||||
session_snapshot=CompositorSessionSnapshot.model_validate_json(row.session_snapshot),
|
||||
backend_run_id=row.backend_run_id,
|
||||
composition_layer_specs=_deserialize_specs(row.composition_layer_specs),
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
def save_active_snapshot(
|
||||
self,
|
||||
*,
|
||||
scope: WorkflowAgentSessionScope,
|
||||
backend_run_id: str,
|
||||
snapshot: CompositorSessionSnapshot | None,
|
||||
composition_layer_specs: list[CleanupLayerSpec],
|
||||
) -> None:
|
||||
if scope.workflow_run_id is None or snapshot is None:
|
||||
return
|
||||
|
||||
snapshot_json = snapshot.model_dump_json()
|
||||
specs_json = _serialize_specs(composition_layer_specs)
|
||||
with session_factory.create_session() as session:
|
||||
row = session.scalar(
|
||||
select(WorkflowAgentRuntimeSession).where(
|
||||
WorkflowAgentRuntimeSession.tenant_id == scope.tenant_id,
|
||||
WorkflowAgentRuntimeSession.workflow_run_id == scope.workflow_run_id,
|
||||
WorkflowAgentRuntimeSession.node_id == scope.node_id,
|
||||
WorkflowAgentRuntimeSession.binding_id == scope.binding_id,
|
||||
WorkflowAgentRuntimeSession.agent_id == scope.agent_id,
|
||||
)
|
||||
)
|
||||
if row is None:
|
||||
row = WorkflowAgentRuntimeSession(
|
||||
tenant_id=scope.tenant_id,
|
||||
app_id=scope.app_id,
|
||||
workflow_id=scope.workflow_id,
|
||||
workflow_run_id=scope.workflow_run_id,
|
||||
node_id=scope.node_id,
|
||||
node_execution_id=scope.node_execution_id,
|
||||
binding_id=scope.binding_id,
|
||||
agent_id=scope.agent_id,
|
||||
agent_config_snapshot_id=scope.agent_config_snapshot_id,
|
||||
backend_run_id=backend_run_id,
|
||||
session_snapshot=snapshot_json,
|
||||
composition_layer_specs=specs_json,
|
||||
status=WorkflowAgentRuntimeSessionStatus.ACTIVE,
|
||||
)
|
||||
session.add(row)
|
||||
else:
|
||||
row.node_execution_id = scope.node_execution_id
|
||||
row.agent_config_snapshot_id = scope.agent_config_snapshot_id
|
||||
row.backend_run_id = backend_run_id
|
||||
row.session_snapshot = snapshot_json
|
||||
row.composition_layer_specs = specs_json
|
||||
row.status = WorkflowAgentRuntimeSessionStatus.ACTIVE
|
||||
row.cleaned_at = None
|
||||
session.commit()
|
||||
|
||||
def mark_cleaned(self, *, scope: WorkflowAgentSessionScope, backend_run_id: str | None = None) -> None:
|
||||
if scope.workflow_run_id is None:
|
||||
return
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
row = session.scalar(
|
||||
select(WorkflowAgentRuntimeSession).where(
|
||||
WorkflowAgentRuntimeSession.tenant_id == scope.tenant_id,
|
||||
WorkflowAgentRuntimeSession.workflow_run_id == scope.workflow_run_id,
|
||||
WorkflowAgentRuntimeSession.node_id == scope.node_id,
|
||||
WorkflowAgentRuntimeSession.binding_id == scope.binding_id,
|
||||
WorkflowAgentRuntimeSession.agent_id == scope.agent_id,
|
||||
WorkflowAgentRuntimeSession.status == WorkflowAgentRuntimeSessionStatus.ACTIVE,
|
||||
)
|
||||
)
|
||||
if row is None:
|
||||
return
|
||||
if backend_run_id is not None:
|
||||
row.backend_run_id = backend_run_id
|
||||
row.status = WorkflowAgentRuntimeSessionStatus.CLEANED
|
||||
row.cleaned_at = naive_utc_now()
|
||||
session.commit()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"StoredWorkflowAgentSession",
|
||||
"WorkflowAgentRuntimeSessionStore",
|
||||
"WorkflowAgentSessionScope",
|
||||
]
|
||||
@ -43,6 +43,11 @@ class SubjectType(StrEnum):
|
||||
EXTERNAL_SSO = "external_sso"
|
||||
|
||||
|
||||
class TokenType(StrEnum):
|
||||
OAUTH_ACCOUNT = "oauth_account"
|
||||
OAUTH_EXTERNAL_SSO = "oauth_external_sso"
|
||||
|
||||
|
||||
class Scope(StrEnum):
|
||||
"""Catalog of bearer scopes recognised by the openapi surface.
|
||||
|
||||
@ -55,6 +60,8 @@ class Scope(StrEnum):
|
||||
APPS_READ = "apps:read"
|
||||
APPS_READ_PERMITTED_EXTERNAL = "apps:read:permitted-external"
|
||||
APPS_RUN = "apps:run"
|
||||
WORKSPACE_READ = "workspace:read"
|
||||
WORKSPACE_WRITE = "workspace:write"
|
||||
|
||||
|
||||
class Accepts(StrEnum):
|
||||
@ -77,7 +84,7 @@ _SUBJECT_TO_ACCEPT: dict[SubjectType, Accepts] = {
|
||||
class AuthContext:
|
||||
"""Per-request identity published via :data:`_auth_ctx_var`
|
||||
(see :func:`set_auth_ctx` / :func:`get_auth_ctx`). ``scopes`` /
|
||||
``subject_type`` / ``source`` come from the TokenKind, not the DB —
|
||||
``subject_type`` / ``token_type`` come from the TokenKind, not the DB —
|
||||
corrupt rows can't elevate scope.
|
||||
|
||||
`verified_tenants` is a snapshot of the Layer-0 verdict cache at
|
||||
@ -92,7 +99,7 @@ class AuthContext:
|
||||
client_id: str | None
|
||||
scopes: frozenset[Scope]
|
||||
token_id: uuid.UUID
|
||||
source: str
|
||||
token_type: TokenType
|
||||
expires_at: datetime | None
|
||||
token_hash: str
|
||||
verified_tenants: dict[str, bool] = field(default_factory=dict)
|
||||
@ -180,7 +187,7 @@ class TokenKind:
|
||||
prefix: str
|
||||
subject_type: SubjectType
|
||||
scopes: frozenset[Scope]
|
||||
source: str
|
||||
token_type: TokenType
|
||||
resolver: Resolver
|
||||
|
||||
def matches(self, token: str) -> bool:
|
||||
@ -291,7 +298,7 @@ class BearerAuthenticator:
|
||||
client_id=row.client_id,
|
||||
scopes=kind.scopes,
|
||||
token_id=row.token_id,
|
||||
source=kind.source,
|
||||
token_type=kind.token_type,
|
||||
expires_at=row.expires_at,
|
||||
token_hash=token_hash,
|
||||
verified_tenants=dict(row.verified_tenants),
|
||||
@ -483,7 +490,7 @@ def check_workspace_membership(
|
||||
account_id: uuid.UUID | str,
|
||||
tenant_id: str,
|
||||
token_hash: str,
|
||||
cached_verdicts: dict[str, bool],
|
||||
membership_cache: dict[str, bool],
|
||||
) -> None:
|
||||
"""Layer-0 enforcement core. Raises `Forbidden` on deny, returns on allow.
|
||||
|
||||
@ -492,7 +499,7 @@ def check_workspace_membership(
|
||||
short-circuiting on EE / SSO subjects before invoking — this function
|
||||
runs the membership + active-status checks unconditionally.
|
||||
"""
|
||||
cached = cached_verdicts.get(tenant_id)
|
||||
cached = membership_cache.get(tenant_id)
|
||||
if cached is True:
|
||||
return
|
||||
if cached is False:
|
||||
@ -530,7 +537,7 @@ def require_workspace_member(ctx: AuthContext, tenant_id: str) -> None:
|
||||
account_id=ctx.account_id,
|
||||
tenant_id=tenant_id,
|
||||
token_hash=ctx.token_hash,
|
||||
cached_verdicts=ctx.verified_tenants,
|
||||
membership_cache=ctx.verified_tenants,
|
||||
)
|
||||
|
||||
|
||||
@ -664,14 +671,14 @@ def build_registry(session_factory, redis_client) -> TokenKindRegistry:
|
||||
prefix=account.prefix,
|
||||
subject_type=account.subject_type,
|
||||
scopes=account.scopes,
|
||||
source="oauth_account",
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
resolver=oauth.for_account(),
|
||||
),
|
||||
TokenKind(
|
||||
prefix=external.prefix,
|
||||
subject_type=external.subject_type,
|
||||
scopes=external.scopes,
|
||||
source="oauth_external_sso",
|
||||
token_type=TokenType.OAUTH_EXTERNAL_SSO,
|
||||
resolver=oauth.for_external_sso(),
|
||||
),
|
||||
]
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""add agent domain models
|
||||
|
||||
Revision ID: c6a9f4b12d3e
|
||||
Revises: f6a7b8c9d012
|
||||
Revises: a4f2d8c9b731
|
||||
Create Date: 2026-05-18 13:30:00.000000
|
||||
|
||||
"""
|
||||
@ -13,7 +13,7 @@ import models
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c6a9f4b12d3e"
|
||||
down_revision = "f6a7b8c9d012"
|
||||
down_revision = "a4f2d8c9b731"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""add learn dify flag to recommended apps
|
||||
|
||||
Revision ID: f5e8a9c0d2b3
|
||||
Revises: c6a9f4b12d3e
|
||||
Revises: a4f2d8c9b731
|
||||
Create Date: 2026-05-18 15:00:00.000000
|
||||
|
||||
"""
|
||||
@ -11,7 +11,7 @@ from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f5e8a9c0d2b3"
|
||||
down_revision = "c6a9f4b12d3e"
|
||||
down_revision = "a4f2d8c9b731"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""add agent config revisions
|
||||
|
||||
Revision ID: f8b6b7e9c421
|
||||
Revises: f5e8a9c0d2b3
|
||||
Revises: c6a9f4b12d3e
|
||||
Create Date: 2026-05-19 10:00:00.000000
|
||||
|
||||
"""
|
||||
@ -13,7 +13,7 @@ import models
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f8b6b7e9c421"
|
||||
down_revision = "f5e8a9c0d2b3"
|
||||
down_revision = "c6a9f4b12d3e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
@ -0,0 +1,90 @@
|
||||
"""add workflow agent runtime sessions
|
||||
|
||||
Revision ID: 7885bd53f9a9
|
||||
Revises: d4a5e1f3c9b7
|
||||
Create Date: 2026-05-27 09:53:54.711805
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
import models as models
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7885bd53f9a9"
|
||||
down_revision = "d4a5e1f3c9b7"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def _is_pg() -> bool:
|
||||
return op.get_bind().dialect.name == "postgresql"
|
||||
|
||||
|
||||
def _uuid_column(name: str, *, nullable: bool = False, primary_key: bool = False) -> sa.Column:
|
||||
"""Match the ``uuidv7()`` default that other tables on Postgres rely on,
|
||||
while staying portable on MySQL where the ORM supplies the id."""
|
||||
kwargs: dict[str, object] = {"nullable": nullable, "primary_key": primary_key}
|
||||
if primary_key and _is_pg():
|
||||
kwargs["server_default"] = sa.text("uuidv7()")
|
||||
return sa.Column(name, models.types.StringUUID(), **kwargs)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"workflow_agent_runtime_sessions",
|
||||
_uuid_column("id", primary_key=True),
|
||||
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("app_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("workflow_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("workflow_run_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("node_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("node_execution_id", sa.String(length=255), nullable=True),
|
||||
sa.Column("binding_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("agent_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("agent_config_snapshot_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("backend_run_id", sa.String(length=255), nullable=True),
|
||||
sa.Column("session_snapshot", models.types.LongText(), nullable=False),
|
||||
# MySQL rejects ``server_default`` on TEXT/BLOB columns. The JSON
|
||||
# payload is always populated at the ORM layer via
|
||||
# ``WorkflowAgentRuntimeSessionStore.save_active_snapshot`` so the
|
||||
# missing DB-level default cannot leave new rows uninitialized.
|
||||
sa.Column("composition_layer_specs", models.types.LongText(), nullable=False),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.String(length=32),
|
||||
server_default=sa.text("'active'"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("cleaned_at", sa.DateTime(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("workflow_agent_runtime_session_pkey")),
|
||||
sa.UniqueConstraint(
|
||||
"tenant_id",
|
||||
"workflow_run_id",
|
||||
"node_id",
|
||||
"binding_id",
|
||||
"agent_id",
|
||||
name=op.f("workflow_agent_runtime_session_scope_unique"),
|
||||
),
|
||||
)
|
||||
with op.batch_alter_table("workflow_agent_runtime_sessions", schema=None) as batch_op:
|
||||
batch_op.create_index(
|
||||
"workflow_agent_runtime_session_lookup_idx",
|
||||
["tenant_id", "workflow_run_id", "node_id", "status"],
|
||||
unique=False,
|
||||
)
|
||||
batch_op.create_index(
|
||||
"workflow_agent_runtime_session_backend_run_idx",
|
||||
["backend_run_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
with op.batch_alter_table("workflow_agent_runtime_sessions", schema=None) as batch_op:
|
||||
batch_op.drop_index("workflow_agent_runtime_session_backend_run_idx")
|
||||
batch_op.drop_index("workflow_agent_runtime_session_lookup_idx")
|
||||
op.drop_table("workflow_agent_runtime_sessions")
|
||||
@ -20,6 +20,8 @@ from .agent import (
|
||||
AgentStatus,
|
||||
WorkflowAgentBindingType,
|
||||
WorkflowAgentNodeBinding,
|
||||
WorkflowAgentRuntimeSession,
|
||||
WorkflowAgentRuntimeSessionStatus,
|
||||
)
|
||||
from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint
|
||||
from .comment import (
|
||||
@ -235,6 +237,8 @@ __all__ = [
|
||||
"Workflow",
|
||||
"WorkflowAgentBindingType",
|
||||
"WorkflowAgentNodeBinding",
|
||||
"WorkflowAgentRuntimeSession",
|
||||
"WorkflowAgentRuntimeSessionStatus",
|
||||
"WorkflowAppLog",
|
||||
"WorkflowAppLogCreatedFrom",
|
||||
"WorkflowArchiveLog",
|
||||
|
||||
@ -92,6 +92,15 @@ class WorkflowAgentBindingType(StrEnum):
|
||||
INLINE_AGENT = "inline_agent"
|
||||
|
||||
|
||||
class WorkflowAgentRuntimeSessionStatus(StrEnum):
|
||||
"""Lifecycle state of an Agent backend session snapshot owned by a workflow run."""
|
||||
|
||||
# Snapshot can be reused by a later Agent run in the same workflow run.
|
||||
ACTIVE = "active"
|
||||
# Snapshot has been retired and must not be submitted to Agent backend again.
|
||||
CLEANED = "cleaned"
|
||||
|
||||
|
||||
class Agent(DefaultFieldsMixin, Base):
|
||||
"""Workspace-scoped Agent identity used by Agent Roster and workflow-only agents."""
|
||||
|
||||
@ -273,3 +282,56 @@ class WorkflowAgentNodeBinding(DefaultFieldsMixin, Base):
|
||||
if isinstance(self.node_job_config, str):
|
||||
return json.loads(self.node_job_config)
|
||||
return dict(self.node_job_config)
|
||||
|
||||
|
||||
class WorkflowAgentRuntimeSession(DefaultFieldsMixin, Base):
|
||||
"""Persisted Agent backend session snapshot for one workflow Agent node execution scope.
|
||||
|
||||
The snapshot is runtime state returned by Agent backend. It is intentionally
|
||||
separate from Agent Soul snapshots and workflow node-job config.
|
||||
"""
|
||||
|
||||
__tablename__ = "workflow_agent_runtime_sessions"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="workflow_agent_runtime_session_pkey"),
|
||||
UniqueConstraint(
|
||||
"tenant_id",
|
||||
"workflow_run_id",
|
||||
"node_id",
|
||||
"binding_id",
|
||||
"agent_id",
|
||||
name="workflow_agent_runtime_session_scope_unique",
|
||||
),
|
||||
Index(
|
||||
"workflow_agent_runtime_session_lookup_idx",
|
||||
"tenant_id",
|
||||
"workflow_run_id",
|
||||
"node_id",
|
||||
"status",
|
||||
),
|
||||
Index("workflow_agent_runtime_session_backend_run_idx", "backend_run_id"),
|
||||
)
|
||||
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
node_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
node_execution_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
binding_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
agent_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
agent_config_snapshot_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
backend_run_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
session_snapshot: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
# JSON-encoded list of ``WorkflowAgentSessionLayerSpec`` ({name, type, deps,
|
||||
# config}). Drives Agent backend cleanup-only runs: the agenton compositor
|
||||
# rejects a session snapshot whose layer names do not match the cleanup
|
||||
# composition, so we must replay the same layer graph (minus credential-
|
||||
# bearing plugin layers) when issuing the cleanup request.
|
||||
composition_layer_specs: Mapped[str] = mapped_column(LongText, nullable=False, server_default="[]")
|
||||
status: Mapped[WorkflowAgentRuntimeSessionStatus] = mapped_column(
|
||||
EnumText(WorkflowAgentRuntimeSessionStatus, length=32),
|
||||
nullable=False,
|
||||
default=WorkflowAgentRuntimeSessionStatus.ACTIVE,
|
||||
)
|
||||
cleaned_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
@ -0,0 +1,84 @@
|
||||
"""
|
||||
Integration tests for delete_account_task.
|
||||
|
||||
These tests keep billing and email dispatch mocked, but exercise the account
|
||||
lookup through the real Testcontainers PostgreSQL session factory instead of a
|
||||
patched session_factory mock.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.account import Account
|
||||
from tasks.delete_account_task import delete_account_task
|
||||
|
||||
|
||||
def _create_account(db_session: Session, *, email: str = "user@example.com") -> Account:
|
||||
account = Account(
|
||||
name=f"account-{uuid4()}",
|
||||
email=email,
|
||||
)
|
||||
db_session.add(account)
|
||||
db_session.commit()
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_dependencies(mocker):
|
||||
billing_service = mocker.patch("tasks.delete_account_task.BillingService")
|
||||
mail_task = mocker.patch("tasks.delete_account_task.send_deletion_success_task")
|
||||
return billing_service, mail_task
|
||||
|
||||
|
||||
def test_billing_enabled_account_exists_calls_billing_and_sends_email(
|
||||
db_session_with_containers: Session, mock_external_dependencies, mocker
|
||||
) -> None:
|
||||
billing_service, mail_task = mock_external_dependencies
|
||||
account = _create_account(db_session_with_containers, email="a@b.com")
|
||||
mocker.patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True)
|
||||
|
||||
delete_account_task(account.id)
|
||||
|
||||
billing_service.delete_account.assert_called_once_with(account.id)
|
||||
mail_task.delay.assert_called_once_with(account.email)
|
||||
|
||||
|
||||
def test_billing_disabled_account_exists_sends_email_only(
|
||||
db_session_with_containers: Session, mock_external_dependencies, mocker
|
||||
) -> None:
|
||||
billing_service, mail_task = mock_external_dependencies
|
||||
account = _create_account(db_session_with_containers, email="x@y.com")
|
||||
mocker.patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", False)
|
||||
|
||||
delete_account_task(account.id)
|
||||
|
||||
billing_service.delete_account.assert_not_called()
|
||||
mail_task.delay.assert_called_once_with(account.email)
|
||||
|
||||
|
||||
def test_billing_enabled_account_not_found_calls_billing_no_email(mock_external_dependencies, mocker, caplog) -> None:
|
||||
billing_service, mail_task = mock_external_dependencies
|
||||
account_id = str(uuid4())
|
||||
mocker.patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True)
|
||||
|
||||
delete_account_task(account_id)
|
||||
|
||||
billing_service.delete_account.assert_called_once_with(account_id)
|
||||
mail_task.delay.assert_not_called()
|
||||
assert any("not found" in record.getMessage().lower() for record in caplog.records)
|
||||
|
||||
|
||||
def test_billing_delete_raises_propagates_and_no_email(
|
||||
db_session_with_containers: Session, mock_external_dependencies, mocker
|
||||
) -> None:
|
||||
billing_service, mail_task = mock_external_dependencies
|
||||
account = _create_account(db_session_with_containers, email="err@example.com")
|
||||
billing_service.delete_account.side_effect = RuntimeError("billing down")
|
||||
mocker.patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True)
|
||||
|
||||
with pytest.raises(RuntimeError, match="billing down"):
|
||||
delete_account_task(account.id)
|
||||
|
||||
mail_task.delay.assert_not_called()
|
||||
@ -0,0 +1,134 @@
|
||||
"""Integration test for the cleanup request against the real agenton compositor.
|
||||
|
||||
The bug fixed by A+D was invisible to unit tests that use ``FakeAgentBackendRunClient``
|
||||
because the fake client never runs agenton's ``_validate_session_snapshot``. This
|
||||
test plugs a cleanup request through the real ``Compositor`` (with the same
|
||||
providers the agent backend wires in production) so that the snapshot-vs-
|
||||
composition name-order check would fail loudly if the cleanup builder ever
|
||||
regressed back to the empty-composition shape.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from agenton.compositor import Compositor, CompositorSessionSnapshot, LayerProvider
|
||||
from agenton.compositor.schemas import LayerSessionSnapshot
|
||||
from agenton.layers.base import LifecycleState
|
||||
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID
|
||||
from agenton_collections.layers.plain.basic import PromptLayer
|
||||
from agenton_collections.layers.pydantic_ai import PYDANTIC_AI_HISTORY_LAYER_TYPE_ID, PydanticAIHistoryLayer
|
||||
|
||||
from clients.agent_backend import AgentBackendRunRequestBuilder, CleanupLayerSpec
|
||||
|
||||
|
||||
def test_cleanup_request_passes_agenton_snapshot_validation():
|
||||
"""The cleanup request's composition layer names must match the (filtered)
|
||||
snapshot's layer names exactly — agenton's compositor enforces this and
|
||||
the agent backend rejects mismatches as ``run_failed`` asynchronously,
|
||||
which is the trap A/D fixed."""
|
||||
# Persisted (non-plugin) layer specs — these are what cleanup will replay.
|
||||
# We exclude the dify.execution_context layer from this integration check
|
||||
# because its real provider needs a plugin-daemon HTTP client; the cleanup
|
||||
# validation we are exercising is the snapshot-vs-composition name check,
|
||||
# which is purely structural and does not depend on which non-plugin layer
|
||||
# types appear.
|
||||
persisted_specs = [
|
||||
CleanupLayerSpec(
|
||||
name="workflow_node_job_prompt",
|
||||
type=PLAIN_PROMPT_LAYER_TYPE_ID,
|
||||
config={"prefix": "Do the cleanup."},
|
||||
),
|
||||
CleanupLayerSpec(name="history", type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID),
|
||||
]
|
||||
# Saved snapshot still carries the LLM layer entry — cleanup's
|
||||
# ``_filter_snapshot_to_specs`` must drop it so names match.
|
||||
full_snapshot = CompositorSessionSnapshot(
|
||||
layers=[
|
||||
LayerSessionSnapshot(
|
||||
name="workflow_node_job_prompt",
|
||||
lifecycle_state=LifecycleState.SUSPENDED,
|
||||
runtime_state={},
|
||||
),
|
||||
LayerSessionSnapshot(
|
||||
name="history",
|
||||
lifecycle_state=LifecycleState.SUSPENDED,
|
||||
runtime_state={"messages": []},
|
||||
),
|
||||
LayerSessionSnapshot(
|
||||
name="llm",
|
||||
lifecycle_state=LifecycleState.SUSPENDED,
|
||||
runtime_state={},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
cleanup_request = AgentBackendRunRequestBuilder().build_cleanup_request(
|
||||
session_snapshot=full_snapshot,
|
||||
composition_layer_specs=persisted_specs,
|
||||
)
|
||||
|
||||
# Drive the real agenton compositor through ``from_config`` + ``_create_run``
|
||||
# the same way the agent backend's RunScheduler does. ``_create_run`` is the
|
||||
# private path that calls ``_validate_session_snapshot``; we use it directly
|
||||
# to keep the test synchronous (no async ``enter()`` lifecycle needed —
|
||||
# validation is the only thing under test).
|
||||
config = {
|
||||
"schema_version": 1,
|
||||
"layers": [
|
||||
{"name": layer.name, "type": layer.type, "deps": dict(layer.deps), "metadata": dict(layer.metadata)}
|
||||
for layer in cleanup_request.composition.layers
|
||||
],
|
||||
}
|
||||
compositor = Compositor.from_config(
|
||||
config,
|
||||
providers=[
|
||||
LayerProvider.from_layer_type(PromptLayer),
|
||||
LayerProvider.from_layer_type(PydanticAIHistoryLayer),
|
||||
],
|
||||
)
|
||||
|
||||
layer_configs = {layer.name: layer.config for layer in cleanup_request.composition.layers}
|
||||
# This is the call that would raise ``ValueError`` if the cleanup snapshot
|
||||
# and composition disagreed on layer names — the exact failure mode the
|
||||
# original ``layers=[]`` cleanup hit.
|
||||
run = compositor._create_run( # type: ignore[reportPrivateUsage]
|
||||
configs=cast(dict[str, object], layer_configs),
|
||||
session_snapshot=cleanup_request.session_snapshot,
|
||||
)
|
||||
assert list(run.slots.keys()) == ["workflow_node_job_prompt", "history"]
|
||||
|
||||
|
||||
def test_cleanup_request_with_mismatched_specs_would_be_rejected_by_agenton():
|
||||
"""Regression sentinel: if a future refactor stops filtering the snapshot,
|
||||
agenton would reject the request — and that rejection is what the runtime
|
||||
fix is preventing. We confirm the validator does fail when given the
|
||||
pre-fix shape so the previous test's success is not a coincidence."""
|
||||
snapshot_with_extra = CompositorSessionSnapshot(
|
||||
layers=[
|
||||
LayerSessionSnapshot(
|
||||
name="history",
|
||||
lifecycle_state=LifecycleState.SUSPENDED,
|
||||
runtime_state={},
|
||||
),
|
||||
LayerSessionSnapshot(
|
||||
name="llm", # extra layer not in composition
|
||||
lifecycle_state=LifecycleState.SUSPENDED,
|
||||
runtime_state={},
|
||||
),
|
||||
]
|
||||
)
|
||||
compositor = Compositor.from_config(
|
||||
{
|
||||
"schema_version": 1,
|
||||
"layers": [{"name": "history", "type": PYDANTIC_AI_HISTORY_LAYER_TYPE_ID, "deps": {}, "metadata": {}}],
|
||||
},
|
||||
providers=[LayerProvider.from_layer_type(PydanticAIHistoryLayer)],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="layer names must match"):
|
||||
compositor._create_run( # type: ignore[reportPrivateUsage]
|
||||
configs={},
|
||||
session_snapshot=snapshot_with_extra,
|
||||
)
|
||||
@ -63,3 +63,25 @@ def test_fake_client_cancel_run_returns_cancelled_status():
|
||||
|
||||
assert cancelled.run_id == "fake-run-1"
|
||||
assert cancelled.status == "cancelled"
|
||||
|
||||
|
||||
def test_fake_client_paused_scenario_returns_paused_status_and_event():
|
||||
"""The paused scenario exists for HITL-style flows; both ``wait_run`` and
|
||||
the event stream must report the pause so consumers can branch on it."""
|
||||
client = FakeAgentBackendRunClient(scenario=FakeAgentBackendScenario.PAUSED)
|
||||
|
||||
status = client.wait_run("fake-run-1")
|
||||
events = list(client.stream_events("fake-run-1"))
|
||||
|
||||
assert status.status == "paused"
|
||||
assert status.error is None
|
||||
assert events[-1].type == "run_paused"
|
||||
assert events[-1].data.reason == "human_input_required"
|
||||
|
||||
|
||||
def test_fake_client_success_wait_run_returns_succeeded_status():
|
||||
"""Covers the default SUCCESS branch of ``wait_run`` directly."""
|
||||
status = FakeAgentBackendRunClient().wait_run("fake-run-1")
|
||||
|
||||
assert status.status == "succeeded"
|
||||
assert status.error is None
|
||||
|
||||
@ -1,15 +1,23 @@
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from agenton.compositor.schemas import LayerSessionSnapshot
|
||||
from agenton.layers import ExitIntent
|
||||
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID
|
||||
from agenton.layers.base import LifecycleState
|
||||
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig
|
||||
from agenton_collections.layers.pydantic_ai import PYDANTIC_AI_HISTORY_LAYER_TYPE_ID
|
||||
from dify_agent.layers.dify_plugin import (
|
||||
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
|
||||
DifyPluginLLMLayerConfig,
|
||||
DifyPluginToolConfig,
|
||||
DifyPluginToolsLayerConfig,
|
||||
)
|
||||
from dify_agent.layers.execution_context import DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, DifyExecutionContextLayerConfig
|
||||
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID
|
||||
from dify_agent.protocol import (
|
||||
DIFY_AGENT_HISTORY_LAYER_ID,
|
||||
DIFY_AGENT_MODEL_LAYER_ID,
|
||||
DIFY_AGENT_OUTPUT_LAYER_ID,
|
||||
CreateRunRequest,
|
||||
@ -26,6 +34,7 @@ from clients.agent_backend import (
|
||||
AgentBackendOutputConfig,
|
||||
AgentBackendRunRequestBuilder,
|
||||
AgentBackendWorkflowNodeRunInput,
|
||||
CleanupLayerSpec,
|
||||
redact_for_agent_backend_log,
|
||||
)
|
||||
|
||||
@ -71,10 +80,11 @@ def test_request_builder_outputs_dify_agent_create_run_request():
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID,
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_ID,
|
||||
DIFY_AGENT_HISTORY_LAYER_ID,
|
||||
DIFY_AGENT_MODEL_LAYER_ID,
|
||||
DIFY_AGENT_OUTPUT_LAYER_ID,
|
||||
]
|
||||
assert request.on_exit.default is ExitIntent.DELETE
|
||||
assert request.on_exit.default is ExitIntent.SUSPEND
|
||||
assert request.idempotency_key == "workflow-run-1:node-execution-1"
|
||||
assert request.metadata == {"workflow_id": "workflow-1", "node_id": "node-1"}
|
||||
|
||||
@ -99,9 +109,10 @@ def test_request_builder_sets_model_and_output_layer_contract_ids():
|
||||
layers = {layer.name: layer for layer in request.composition.layers}
|
||||
|
||||
assert layers[DIFY_EXECUTION_CONTEXT_LAYER_ID].type == DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID
|
||||
assert layers[DIFY_EXECUTION_CONTEXT_LAYER_ID].config.user_id == "user-1"
|
||||
assert cast(DifyExecutionContextLayerConfig, layers[DIFY_EXECUTION_CONTEXT_LAYER_ID].config).user_id == "user-1"
|
||||
assert layers[DIFY_AGENT_HISTORY_LAYER_ID].type == PYDANTIC_AI_HISTORY_LAYER_TYPE_ID
|
||||
assert layers[DIFY_AGENT_MODEL_LAYER_ID].type == DIFY_PLUGIN_LLM_LAYER_TYPE_ID
|
||||
assert layers[DIFY_AGENT_MODEL_LAYER_ID].config.plugin_id == "langgenius/openai"
|
||||
assert cast(DifyPluginLLMLayerConfig, layers[DIFY_AGENT_MODEL_LAYER_ID].config).plugin_id == "langgenius/openai"
|
||||
assert layers[DIFY_AGENT_MODEL_LAYER_ID].deps == {"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID}
|
||||
assert layers[DIFY_AGENT_OUTPUT_LAYER_ID].type == DIFY_OUTPUT_LAYER_TYPE_ID
|
||||
|
||||
@ -130,16 +141,92 @@ def test_request_builder_adds_dify_plugin_tools_layer_when_configured():
|
||||
|
||||
assert layers[DIFY_PLUGIN_TOOLS_LAYER_ID].type == DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID
|
||||
assert layers[DIFY_PLUGIN_TOOLS_LAYER_ID].deps == {"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID}
|
||||
assert layers[DIFY_PLUGIN_TOOLS_LAYER_ID].config.tools[0].tool_name == "current_time"
|
||||
tools_config = cast(DifyPluginToolsLayerConfig, layers[DIFY_PLUGIN_TOOLS_LAYER_ID].config)
|
||||
assert tools_config.tools[0].tool_name == "current_time"
|
||||
|
||||
|
||||
def test_request_builder_can_suspend_on_exit_for_resume_or_babysit_paths():
|
||||
def test_request_builder_can_delete_on_exit_for_cleanup_paths():
|
||||
run_input = _run_input()
|
||||
run_input.suspend_on_exit = True
|
||||
run_input.suspend_on_exit = False
|
||||
|
||||
request = AgentBackendRunRequestBuilder().build_for_workflow_node(run_input)
|
||||
|
||||
assert request.on_exit.default is ExitIntent.SUSPEND
|
||||
assert request.on_exit.default is ExitIntent.DELETE
|
||||
|
||||
|
||||
def test_request_builder_builds_cleanup_request_replays_persisted_layer_specs():
|
||||
"""The cleanup request must replay the persisted (non-plugin) layer specs
|
||||
and filter the snapshot to match so the agenton compositor's
|
||||
snapshot-vs-composition name-order validator passes."""
|
||||
session_snapshot = CompositorSessionSnapshot(
|
||||
layers=[
|
||||
LayerSessionSnapshot(name="history", lifecycle_state=LifecycleState.SUSPENDED, runtime_state={"k": 1}),
|
||||
LayerSessionSnapshot(name="llm", lifecycle_state=LifecycleState.SUSPENDED, runtime_state={}),
|
||||
]
|
||||
)
|
||||
specs = [CleanupLayerSpec(name="history", type="pydantic_ai.history")]
|
||||
|
||||
request = AgentBackendRunRequestBuilder().build_cleanup_request(
|
||||
session_snapshot=session_snapshot,
|
||||
composition_layer_specs=specs,
|
||||
idempotency_key="run-1:node-1:binding-1:agent-session-cleanup",
|
||||
metadata={"workflow_run_id": "run-1"},
|
||||
)
|
||||
|
||||
assert [layer.name for layer in request.composition.layers] == ["history"]
|
||||
assert request.session_snapshot is not None
|
||||
assert [layer.name for layer in request.session_snapshot.layers] == ["history"]
|
||||
assert request.on_exit.default is ExitIntent.DELETE
|
||||
assert request.idempotency_key == "run-1:node-1:binding-1:agent-session-cleanup"
|
||||
assert request.metadata["agent_backend_lifecycle"] == "session_cleanup"
|
||||
|
||||
|
||||
def test_request_builder_rejects_empty_composition_layer_specs():
|
||||
"""Empty specs would put us back in the original ``layers=[]`` trap that
|
||||
fails on agenton's snapshot-vs-composition validation."""
|
||||
with pytest.raises(ValueError, match="composition_layer_specs"):
|
||||
AgentBackendRunRequestBuilder().build_cleanup_request(
|
||||
session_snapshot=CompositorSessionSnapshot(layers=[]),
|
||||
composition_layer_specs=[],
|
||||
)
|
||||
|
||||
|
||||
def test_extract_cleanup_layer_specs_drops_plugin_layers_keeps_configs():
|
||||
from dify_agent.protocol import RunComposition, RunLayerSpec
|
||||
|
||||
from clients.agent_backend import extract_cleanup_layer_specs
|
||||
|
||||
composition = RunComposition(
|
||||
layers=[
|
||||
RunLayerSpec(
|
||||
name="agent_soul_prompt",
|
||||
type="plain.prompt",
|
||||
config=PromptLayerConfig(prefix="hello"),
|
||||
),
|
||||
RunLayerSpec(
|
||||
name="llm",
|
||||
type="dify.plugin.llm",
|
||||
config=None, # protocol allows None; the redacted config is what matters
|
||||
),
|
||||
RunLayerSpec(
|
||||
name="tools",
|
||||
type="dify.plugin.tools",
|
||||
),
|
||||
RunLayerSpec(
|
||||
name="history",
|
||||
type="pydantic_ai.history",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
specs = extract_cleanup_layer_specs(composition)
|
||||
|
||||
assert [spec.name for spec in specs] == ["agent_soul_prompt", "history"]
|
||||
# Non-plugin configs are dumped as JSON-compatible dicts so the persisted
|
||||
# row can be replayed without holding live pydantic instances.
|
||||
soul_config = specs[0].config
|
||||
assert isinstance(soul_config, dict)
|
||||
assert soul_config.get("prefix") == "hello"
|
||||
|
||||
|
||||
def test_request_builder_rejects_blank_prompts():
|
||||
@ -159,6 +246,6 @@ def test_request_builder_rejects_blank_prompts():
|
||||
def test_redact_for_agent_backend_log_hides_credentials():
|
||||
request = AgentBackendRunRequestBuilder().build_for_workflow_node(_run_input())
|
||||
|
||||
redacted = redact_for_agent_backend_log(request)
|
||||
redacted = cast(dict[str, Any], redact_for_agent_backend_log(request))
|
||||
|
||||
assert redacted["composition"]["layers"][4]["config"]["credentials"] == "[REDACTED]"
|
||||
assert redacted["composition"]["layers"][5]["config"]["credentials"] == "[REDACTED]"
|
||||
|
||||
@ -1,66 +1,73 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE, _resolve_app_authz_strategy
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
from controllers.openapi.auth.steps import (
|
||||
AppAuthzCheck,
|
||||
AppResolver,
|
||||
BearerCheck,
|
||||
CallerMount,
|
||||
ScopeCheck,
|
||||
SurfaceCheck,
|
||||
WorkspaceMembershipCheck,
|
||||
)
|
||||
from controllers.openapi.auth.strategies import (
|
||||
AccountMounter,
|
||||
AclStrategy,
|
||||
EndUserMounter,
|
||||
MembershipStrategy,
|
||||
)
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from controllers.openapi.auth.composition import account_pipeline, auth_router, external_sso_pipeline
|
||||
from controllers.openapi.auth.flow import When
|
||||
from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter
|
||||
from libs.oauth_bearer import TokenType
|
||||
|
||||
|
||||
def test_pipeline_is_composed():
|
||||
assert isinstance(OAUTH_BEARER_PIPELINE, Pipeline)
|
||||
def test_account_pipeline_is_auth_pipeline():
|
||||
assert isinstance(account_pipeline, AuthPipeline)
|
||||
|
||||
|
||||
def test_pipeline_step_order():
|
||||
"""BearerCheck → SurfaceCheck → ScopeCheck → AppResolver →
|
||||
WorkspaceMembershipCheck → AppAuthzCheck → CallerMount.
|
||||
SurfaceCheck enforces the dfoa_/dfoe_ surface split + emits
|
||||
`openapi.wrong_surface_denied`. Rate-limit is enforced inside
|
||||
`BearerAuthenticator.authenticate`, not as a separate pipeline step."""
|
||||
steps = OAUTH_BEARER_PIPELINE._steps
|
||||
assert isinstance(steps[0], BearerCheck)
|
||||
assert isinstance(steps[1], SurfaceCheck)
|
||||
assert isinstance(steps[2], ScopeCheck)
|
||||
assert isinstance(steps[3], AppResolver)
|
||||
assert isinstance(steps[4], WorkspaceMembershipCheck)
|
||||
assert isinstance(steps[5], AppAuthzCheck)
|
||||
assert isinstance(steps[6], CallerMount)
|
||||
def test_external_sso_pipeline_is_auth_pipeline():
|
||||
assert isinstance(external_sso_pipeline, AuthPipeline)
|
||||
|
||||
|
||||
def test_pipeline_surface_check_accepts_account_only():
|
||||
"""Current pipeline serves /apps/<id>/run — account surface only."""
|
||||
surface = OAUTH_BEARER_PIPELINE._steps[1]
|
||||
assert isinstance(surface, SurfaceCheck)
|
||||
assert surface._accepted == frozenset({SubjectType.ACCOUNT})
|
||||
def test_auth_router_is_pipeline_router():
|
||||
assert isinstance(auth_router, PipelineRouter)
|
||||
|
||||
|
||||
def test_caller_mount_has_both_mounters():
|
||||
cm = OAUTH_BEARER_PIPELINE._steps[6]
|
||||
kinds = {type(m) for m in cm._mounters}
|
||||
assert AccountMounter in kinds
|
||||
assert EndUserMounter in kinds
|
||||
def test_account_pipeline_prepare_has_four_entries():
|
||||
assert len(account_pipeline._prepare) == 4
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.composition.FeatureService")
|
||||
def test_strategy_resolver_picks_acl_when_enabled(fs):
|
||||
fs.get_system_features.return_value.webapp_auth.enabled = True
|
||||
assert isinstance(_resolve_app_authz_strategy(), AclStrategy)
|
||||
def test_account_auth_list_has_five_entries():
|
||||
assert len(account_pipeline._auth) == 5
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.composition.FeatureService")
|
||||
def test_strategy_resolver_picks_membership_when_disabled(fs):
|
||||
fs.get_system_features.return_value.webapp_auth.enabled = False
|
||||
assert isinstance(_resolve_app_authz_strategy(), MembershipStrategy)
|
||||
def test_external_sso_pipeline_prepare_has_four_entries():
|
||||
assert len(external_sso_pipeline._prepare) == 4
|
||||
|
||||
|
||||
def test_external_sso_auth_list_has_three_entries():
|
||||
assert len(external_sso_pipeline._auth) == 3
|
||||
|
||||
|
||||
def test_account_pipeline_has_unconditional_load_account():
|
||||
non_when = [s for s in account_pipeline._prepare if not isinstance(s, When)]
|
||||
assert len(non_when) == 1
|
||||
|
||||
|
||||
def test_external_sso_pipeline_all_prepare_entries_are_when():
|
||||
assert all(isinstance(s, When) for s in external_sso_pipeline._prepare)
|
||||
|
||||
|
||||
def test_first_auth_entry_is_check_scope_in_both_pipelines():
|
||||
assert not isinstance(account_pipeline._auth[0], When)
|
||||
assert not isinstance(external_sso_pipeline._auth[0], When)
|
||||
|
||||
|
||||
def test_remaining_auth_entries_are_when_for_account():
|
||||
assert all(isinstance(s, When) for s in account_pipeline._auth[1:])
|
||||
|
||||
|
||||
def test_remaining_auth_entries_are_when_for_external_sso():
|
||||
assert all(isinstance(s, When) for s in external_sso_pipeline._auth[1:])
|
||||
|
||||
|
||||
def test_router_routes_contain_both_token_types():
|
||||
assert TokenType.OAUTH_ACCOUNT in auth_router._routes
|
||||
assert TokenType.OAUTH_EXTERNAL_SSO in auth_router._routes
|
||||
|
||||
|
||||
def test_external_sso_route_has_ee_required_edition():
|
||||
route = auth_router._routes[TokenType.OAUTH_EXTERNAL_SSO]
|
||||
assert isinstance(route, PipelineRoute)
|
||||
from controllers.openapi.auth.data import Edition
|
||||
|
||||
assert route.required_edition == frozenset({Edition.EE})
|
||||
|
||||
|
||||
def test_account_route_has_no_required_edition():
|
||||
route = auth_router._routes[TokenType.OAUTH_ACCOUNT]
|
||||
assert isinstance(route, PipelineRoute)
|
||||
assert route.required_edition is None
|
||||
|
||||
143
api/tests/unit_tests/controllers/openapi/auth/test_conditions.py
Normal file
143
api/tests/unit_tests/controllers/openapi/auth/test_conditions.py
Normal file
@ -0,0 +1,143 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from controllers.openapi.auth.conditions import (
|
||||
EDITION_CE,
|
||||
EDITION_EE,
|
||||
EDITION_SAAS,
|
||||
LOADED_APP_IS_PRIVATE,
|
||||
PATH_HAS_APP_ID,
|
||||
TOKEN_IS_OAUTH_ACCOUNT,
|
||||
TOKEN_IS_OAUTH_EXTERNAL_SSO,
|
||||
WEBAPP_AUTH_ENABLED,
|
||||
Cond,
|
||||
config_cond,
|
||||
data_cond,
|
||||
request_cond,
|
||||
)
|
||||
from controllers.openapi.auth.data import AuthData, Edition, RequestContext
|
||||
from libs.oauth_bearer import TokenType
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
|
||||
|
||||
def _ctx(token_type=TokenType.OAUTH_ACCOUNT, path_params=None):
|
||||
return RequestContext(
|
||||
token_type=token_type,
|
||||
path_params=path_params or {},
|
||||
)
|
||||
|
||||
|
||||
def _data(**kwargs):
|
||||
defaults: dict = {"token_type": TokenType.OAUTH_ACCOUNT, "token_hash": "x", "scopes": frozenset()}
|
||||
defaults.update(kwargs)
|
||||
return AuthData(**defaults)
|
||||
|
||||
|
||||
def test_and_both_true():
|
||||
a = Cond(lambda ctx, _: True)
|
||||
b = Cond(lambda ctx, _: True)
|
||||
assert (a & b)(_ctx()) is True
|
||||
|
||||
|
||||
def test_and_one_false():
|
||||
a = Cond(lambda ctx, _: True)
|
||||
b = Cond(lambda ctx, _: False)
|
||||
assert (a & b)(_ctx()) is False
|
||||
|
||||
|
||||
def test_or_one_true():
|
||||
a = Cond(lambda ctx, _: False)
|
||||
b = Cond(lambda ctx, _: True)
|
||||
assert (a | b)(_ctx()) is True
|
||||
|
||||
|
||||
def test_or_both_false():
|
||||
a = Cond(lambda ctx, _: False)
|
||||
b = Cond(lambda ctx, _: False)
|
||||
assert (a | b)(_ctx()) is False
|
||||
|
||||
|
||||
def test_invert():
|
||||
a = Cond(lambda ctx, _: True)
|
||||
assert (~a)(_ctx()) is False
|
||||
|
||||
|
||||
def test_chain_and_or():
|
||||
always_true = Cond(lambda ctx, _: True)
|
||||
always_false = Cond(lambda ctx, _: False)
|
||||
assert ((always_true | always_false) & always_true)(_ctx()) is True
|
||||
|
||||
|
||||
def test_request_cond_ignores_data():
|
||||
c = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_ACCOUNT)
|
||||
assert c(_ctx(TokenType.OAUTH_ACCOUNT)) is True
|
||||
assert c(_ctx(TokenType.OAUTH_EXTERNAL_SSO)) is False
|
||||
|
||||
|
||||
def test_data_cond_returns_false_when_data_none():
|
||||
c = data_cond(lambda data: True)
|
||||
assert c(_ctx(), None) is False
|
||||
|
||||
|
||||
def test_data_cond_evaluates_when_data_present():
|
||||
c = data_cond(lambda data: data.token_hash == "secret")
|
||||
assert c(_ctx(), _data(token_hash="secret")) is True
|
||||
assert c(_ctx(), _data(token_hash="other")) is False
|
||||
|
||||
|
||||
def test_config_cond_ignores_ctx_and_data():
|
||||
c = config_cond(lambda: True)
|
||||
assert c(_ctx()) is True
|
||||
c2 = config_cond(lambda: False)
|
||||
assert c2(_ctx(), _data()) is False
|
||||
|
||||
|
||||
def test_token_is_oauth_account():
|
||||
assert TOKEN_IS_OAUTH_ACCOUNT(_ctx(TokenType.OAUTH_ACCOUNT)) is True
|
||||
assert TOKEN_IS_OAUTH_ACCOUNT(_ctx(TokenType.OAUTH_EXTERNAL_SSO)) is False
|
||||
|
||||
|
||||
def test_token_is_oauth_external_sso():
|
||||
assert TOKEN_IS_OAUTH_EXTERNAL_SSO(_ctx(TokenType.OAUTH_EXTERNAL_SSO)) is True
|
||||
|
||||
|
||||
def test_path_has_app_id_true():
|
||||
assert PATH_HAS_APP_ID(_ctx(path_params={"app_id": "abc"})) is True
|
||||
|
||||
|
||||
def test_path_has_app_id_false():
|
||||
assert PATH_HAS_APP_ID(_ctx(path_params={})) is False
|
||||
|
||||
|
||||
def test_edition_ce():
|
||||
with patch("controllers.openapi.auth.conditions.current_edition", return_value=Edition.CE):
|
||||
assert EDITION_CE(_ctx()) is True
|
||||
assert EDITION_EE(_ctx()) is False
|
||||
assert EDITION_SAAS(_ctx()) is False
|
||||
|
||||
|
||||
def test_edition_ee():
|
||||
with patch("controllers.openapi.auth.conditions.current_edition", return_value=Edition.EE):
|
||||
assert EDITION_EE(_ctx()) is True
|
||||
assert EDITION_CE(_ctx()) is False
|
||||
|
||||
|
||||
def test_edition_saas():
|
||||
with patch("controllers.openapi.auth.conditions.current_edition", return_value=Edition.SAAS):
|
||||
assert EDITION_SAAS(_ctx()) is True
|
||||
|
||||
|
||||
def test_webapp_auth_enabled():
|
||||
mock_features = MagicMock()
|
||||
mock_features.webapp_auth.enabled = True
|
||||
with patch("controllers.openapi.auth.conditions.FeatureService.get_system_features", return_value=mock_features):
|
||||
assert WEBAPP_AUTH_ENABLED(_ctx()) is True
|
||||
|
||||
|
||||
def test_loaded_app_is_private():
|
||||
data_private = _data(app_access_mode=WebAppAccessMode.PRIVATE)
|
||||
data_public = _data(app_access_mode=WebAppAccessMode.PUBLIC)
|
||||
data_none = _data(app_access_mode=None)
|
||||
assert LOADED_APP_IS_PRIVATE(_ctx(), data_private) is True
|
||||
assert LOADED_APP_IS_PRIVATE(_ctx(), data_public) is False
|
||||
assert LOADED_APP_IS_PRIVATE(_ctx(), data_none) is False
|
||||
assert LOADED_APP_IS_PRIVATE(_ctx(), None) is False
|
||||
@ -1,21 +0,0 @@
|
||||
from controllers.openapi.auth.context import Context
|
||||
|
||||
|
||||
def test_context_starts_unpopulated():
|
||||
ctx = Context(required_scope="apps:run")
|
||||
assert ctx.bearer_token is None
|
||||
assert ctx.path_params == {}
|
||||
assert ctx.subject_type is None
|
||||
assert ctx.subject_email is None
|
||||
assert ctx.account_id is None
|
||||
assert ctx.scopes == frozenset()
|
||||
assert ctx.app is None
|
||||
assert ctx.tenant is None
|
||||
assert ctx.caller is None
|
||||
assert ctx.caller_kind is None
|
||||
|
||||
|
||||
def test_context_fields_are_mutable():
|
||||
ctx = Context(required_scope="apps:run")
|
||||
ctx.scopes = frozenset({"full"})
|
||||
assert "full" in ctx.scopes
|
||||
117
api/tests/unit_tests/controllers/openapi/auth/test_data.py
Normal file
117
api/tests/unit_tests/controllers/openapi/auth/test_data.py
Normal file
@ -0,0 +1,117 @@
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.openapi.auth.data import (
|
||||
AuthData,
|
||||
Edition,
|
||||
ExternalIdentity,
|
||||
RequestContext,
|
||||
current_edition,
|
||||
)
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
|
||||
|
||||
def test_current_edition_saas():
|
||||
with patch("controllers.openapi.auth.data.dify_config") as cfg:
|
||||
cfg.EDITION = "CLOUD"
|
||||
cfg.ENTERPRISE_ENABLED = True
|
||||
assert current_edition() == Edition.SAAS
|
||||
|
||||
|
||||
def test_current_edition_ee():
|
||||
with patch("controllers.openapi.auth.data.dify_config") as cfg:
|
||||
cfg.EDITION = "SELF_HOSTED"
|
||||
cfg.ENTERPRISE_ENABLED = True
|
||||
assert current_edition() == Edition.EE
|
||||
|
||||
|
||||
def test_current_edition_ce():
|
||||
with patch("controllers.openapi.auth.data.dify_config") as cfg:
|
||||
cfg.EDITION = "SELF_HOSTED"
|
||||
cfg.ENTERPRISE_ENABLED = False
|
||||
assert current_edition() == Edition.CE
|
||||
|
||||
|
||||
def test_external_identity_frozen():
|
||||
ei = ExternalIdentity(email="a@b.com", issuer="idp")
|
||||
with pytest.raises(ValidationError):
|
||||
ei.email = "other@b.com" # type: ignore[misc]
|
||||
|
||||
|
||||
def test_external_identity_issuer_optional():
|
||||
ei = ExternalIdentity(email="a@b.com")
|
||||
assert ei.issuer is None
|
||||
|
||||
|
||||
def test_request_context_frozen():
|
||||
ctx = RequestContext(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
path_params={"app_id": "123"},
|
||||
)
|
||||
with pytest.raises(ValidationError):
|
||||
ctx.token_type = TokenType.OAUTH_EXTERNAL_SSO # type: ignore[misc]
|
||||
|
||||
|
||||
def test_request_context_scope_optional():
|
||||
ctx = RequestContext(token_type=TokenType.OAUTH_ACCOUNT, path_params={})
|
||||
assert ctx.scope is None
|
||||
|
||||
|
||||
def test_auth_data_is_mutable():
|
||||
data = AuthData(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
token_hash="abc",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
)
|
||||
data.token_type = TokenType.OAUTH_EXTERNAL_SSO
|
||||
assert data.token_type == TokenType.OAUTH_EXTERNAL_SSO
|
||||
|
||||
|
||||
def test_auth_data_path_params_defaults_empty():
|
||||
data = AuthData(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
token_hash="abc",
|
||||
scopes=frozenset(),
|
||||
)
|
||||
assert data.path_params == {}
|
||||
|
||||
|
||||
def test_auth_data_account_id_optional():
|
||||
data = AuthData(
|
||||
token_type=TokenType.OAUTH_EXTERNAL_SSO,
|
||||
token_hash="abc",
|
||||
scopes=frozenset({Scope.APPS_RUN}),
|
||||
external_identity=ExternalIdentity(email="u@sso.com"),
|
||||
)
|
||||
assert data.account_id is None
|
||||
|
||||
|
||||
def test_auth_data_external_identity_none_for_account():
|
||||
data = AuthData(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
account_id=uuid.uuid4(),
|
||||
token_hash="abc",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
)
|
||||
assert data.external_identity is None
|
||||
|
||||
|
||||
def test_auth_data_tenants_default_empty():
|
||||
data = AuthData(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
token_hash="abc",
|
||||
scopes=frozenset(),
|
||||
)
|
||||
assert data.tenants == {}
|
||||
|
||||
|
||||
def test_auth_data_token_id_optional():
|
||||
data = AuthData(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
token_hash="abc",
|
||||
scopes=frozenset(),
|
||||
)
|
||||
assert data.token_id is None
|
||||
42
api/tests/unit_tests/controllers/openapi/auth/test_flow.py
Normal file
42
api/tests/unit_tests/controllers/openapi/auth/test_flow.py
Normal file
@ -0,0 +1,42 @@
|
||||
import inspect
|
||||
|
||||
from controllers.openapi.auth.conditions import Cond
|
||||
from controllers.openapi.auth.data import AuthData, RequestContext
|
||||
from controllers.openapi.auth.flow import When
|
||||
from libs.oauth_bearer import TokenType
|
||||
|
||||
|
||||
def _ctx():
|
||||
return RequestContext(token_type=TokenType.OAUTH_ACCOUNT, path_params={})
|
||||
|
||||
|
||||
def _data():
|
||||
return AuthData(token_type=TokenType.OAUTH_ACCOUNT, token_hash="x", scopes=frozenset())
|
||||
|
||||
|
||||
def test_applies_returns_true_when_condition_true():
|
||||
w = When(Cond(lambda ctx, _: True), then=lambda b: None)
|
||||
assert w.applies(_ctx()) is True
|
||||
|
||||
|
||||
def test_applies_returns_false_when_condition_false():
|
||||
w = When(Cond(lambda ctx, _: False), then=lambda b: None)
|
||||
assert w.applies(_ctx()) is False
|
||||
|
||||
|
||||
def test_applies_with_data():
|
||||
w = When(Cond(lambda ctx, data: data is not None), then=lambda b: None)
|
||||
assert w.applies(_ctx(), _data()) is True
|
||||
assert w.applies(_ctx(), None) is False
|
||||
|
||||
|
||||
def test_call_invokes_step():
|
||||
calls = []
|
||||
w = When(Cond(lambda ctx, _: True), then=lambda arg: calls.append(arg))
|
||||
w("payload")
|
||||
assert calls == ["payload"]
|
||||
|
||||
|
||||
def test_then_is_keyword_only():
|
||||
sig = inspect.signature(When.__init__)
|
||||
assert sig.parameters["then"].kind.name == "KEYWORD_ONLY"
|
||||
@ -1,59 +1,269 @@
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
from controllers.openapi.auth.data import AuthData, Edition
|
||||
from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
|
||||
|
||||
def test_run_invokes_each_step_in_order():
|
||||
calls = []
|
||||
|
||||
class S:
|
||||
def __init__(self, tag):
|
||||
self.tag = tag
|
||||
|
||||
def __call__(self, ctx):
|
||||
calls.append(self.tag)
|
||||
|
||||
Pipeline(S("a"), S("b"), S("c")).run(Context(required_scope="x"))
|
||||
assert calls == ["a", "b", "c"]
|
||||
def _make_identity(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
account_id=None,
|
||||
scopes=None,
|
||||
token_hash="testhash",
|
||||
subject_email=None,
|
||||
subject_issuer=None,
|
||||
verified_tenants=None,
|
||||
token_id=None,
|
||||
):
|
||||
identity = MagicMock()
|
||||
identity.token_type = token_type
|
||||
identity.account_id = account_id or uuid.uuid4()
|
||||
identity.scopes = scopes or frozenset({Scope.FULL})
|
||||
identity.token_hash = token_hash
|
||||
identity.subject_email = subject_email
|
||||
identity.subject_issuer = subject_issuer
|
||||
identity.verified_tenants = verified_tenants or {}
|
||||
identity.token_id = token_id or uuid.uuid4()
|
||||
return identity
|
||||
|
||||
|
||||
def test_run_short_circuits_on_raise():
|
||||
calls = []
|
||||
|
||||
class Boom:
|
||||
def __call__(self, ctx):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
class Tail:
|
||||
def __call__(self, ctx):
|
||||
calls.append("ran")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
Pipeline(Boom(), Tail()).run(Context(required_scope="x"))
|
||||
assert calls == []
|
||||
@pytest.fixture
|
||||
def app():
|
||||
return Flask(__name__)
|
||||
|
||||
|
||||
def test_guard_decorator_runs_pipeline_and_unpacks_handler_kwargs():
|
||||
seen = {}
|
||||
def _make_router(token_type=TokenType.OAUTH_ACCOUNT, prepare=None, auth=None):
|
||||
pipeline = AuthPipeline(prepare=prepare or [], auth=auth or [])
|
||||
return PipelineRouter({token_type: PipelineRoute(pipeline)})
|
||||
|
||||
class FakeStep:
|
||||
def __call__(self, ctx):
|
||||
ctx.app = "APP"
|
||||
ctx.caller = "CALLER"
|
||||
ctx.caller_kind = "account"
|
||||
|
||||
pipeline = Pipeline(FakeStep())
|
||||
def _fake_identity():
|
||||
return _make_identity()
|
||||
|
||||
@pipeline.guard(scope="apps:run")
|
||||
def handler(app_model, caller, caller_kind):
|
||||
seen["app_model"] = app_model
|
||||
seen["caller"] = caller
|
||||
seen["caller_kind"] = caller_kind
|
||||
return "ok"
|
||||
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/x", method="POST"):
|
||||
assert handler() == "ok"
|
||||
assert seen == {"app_model": "APP", "caller": "CALLER", "caller_kind": "account"}
|
||||
# --- PipelineRouter.guard ---
|
||||
|
||||
|
||||
def test_guard_passes_auth_data_to_view(app):
|
||||
router = _make_router()
|
||||
received = {}
|
||||
|
||||
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
|
||||
with (
|
||||
patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
|
||||
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
|
||||
patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()),
|
||||
patch("controllers.openapi.auth.pipeline.reset_auth_ctx"),
|
||||
):
|
||||
mock_auth.return_value.authenticate.return_value = _fake_identity()
|
||||
|
||||
@router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def view(*, auth_data):
|
||||
received["data"] = auth_data
|
||||
|
||||
view()
|
||||
|
||||
assert isinstance(received["data"], AuthData)
|
||||
|
||||
|
||||
def test_guard_edition_gate_returns_404(app):
|
||||
router = _make_router()
|
||||
|
||||
with app.test_request_context("/test"):
|
||||
with patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE):
|
||||
|
||||
@router.guard(scope=Scope.FULL, edition=frozenset({Edition.EE}))
|
||||
def view(*, auth_data):
|
||||
pass
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
view()
|
||||
|
||||
|
||||
def test_guard_token_type_gate_returns_403(app):
|
||||
router = _make_router()
|
||||
|
||||
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
|
||||
with (
|
||||
patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
|
||||
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
|
||||
patch("controllers.openapi.auth.pipeline.emit_wrong_surface"),
|
||||
patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE),
|
||||
):
|
||||
identity = _fake_identity()
|
||||
identity.token_type = TokenType.OAUTH_EXTERNAL_SSO
|
||||
mock_auth.return_value.authenticate.return_value = identity
|
||||
|
||||
@router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def view(*, auth_data):
|
||||
pass
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
view()
|
||||
|
||||
|
||||
def test_guard_unregistered_token_type_returns_403(app):
|
||||
router = _make_router(token_type=TokenType.OAUTH_ACCOUNT)
|
||||
|
||||
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
|
||||
with (
|
||||
patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
|
||||
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
|
||||
patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE),
|
||||
):
|
||||
identity = _fake_identity()
|
||||
identity.token_type = TokenType.OAUTH_EXTERNAL_SSO
|
||||
mock_auth.return_value.authenticate.return_value = identity
|
||||
|
||||
@router.guard(scope=Scope.FULL)
|
||||
def view(*, auth_data):
|
||||
pass
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
view()
|
||||
|
||||
|
||||
def test_guard_no_bearer_returns_401(app):
|
||||
router = _make_router()
|
||||
|
||||
with app.test_request_context("/test"):
|
||||
with patch("controllers.openapi.auth.pipeline.extract_bearer", return_value=None):
|
||||
|
||||
@router.guard(scope=Scope.FULL)
|
||||
def view(*, auth_data):
|
||||
pass
|
||||
|
||||
with pytest.raises(Unauthorized):
|
||||
view()
|
||||
|
||||
|
||||
def test_guard_runs_prepare_steps_in_order(app):
|
||||
order = []
|
||||
|
||||
def p1(b):
|
||||
order.append("p1")
|
||||
|
||||
def p2(b):
|
||||
order.append("p2")
|
||||
|
||||
router = _make_router(prepare=[p1, p2])
|
||||
|
||||
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
|
||||
with (
|
||||
patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
|
||||
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
|
||||
patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()),
|
||||
patch("controllers.openapi.auth.pipeline.reset_auth_ctx"),
|
||||
):
|
||||
mock_auth.return_value.authenticate.return_value = _fake_identity()
|
||||
|
||||
@router.guard(scope=Scope.FULL)
|
||||
def view(*, auth_data):
|
||||
pass
|
||||
|
||||
view()
|
||||
|
||||
assert order == ["p1", "p2"]
|
||||
|
||||
|
||||
def test_guard_resets_auth_ctx_on_exception(app):
|
||||
router = _make_router()
|
||||
reset_called = []
|
||||
|
||||
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
|
||||
with (
|
||||
patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
|
||||
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
|
||||
patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value="tok"),
|
||||
patch("controllers.openapi.auth.pipeline.reset_auth_ctx", side_effect=lambda t: reset_called.append(t)),
|
||||
):
|
||||
mock_auth.return_value.authenticate.return_value = _fake_identity()
|
||||
|
||||
@router.guard(scope=Scope.FULL)
|
||||
def view(*, auth_data):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
view()
|
||||
|
||||
assert reset_called == ["tok"]
|
||||
|
||||
|
||||
def test_router_rejects_token_type_on_wrong_edition(app):
|
||||
pipeline = AuthPipeline(prepare=[], auth=[])
|
||||
route = PipelineRoute(pipeline, required_edition=frozenset({Edition.EE}))
|
||||
router = PipelineRouter({TokenType.OAUTH_EXTERNAL_SSO: route})
|
||||
|
||||
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
|
||||
with (
|
||||
patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
|
||||
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
|
||||
patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE),
|
||||
):
|
||||
identity = _make_identity(token_type=TokenType.OAUTH_EXTERNAL_SSO)
|
||||
mock_auth.return_value.authenticate.return_value = identity
|
||||
|
||||
@router.guard(scope=Scope.APPS_RUN)
|
||||
def view(*, auth_data):
|
||||
pass
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
view()
|
||||
|
||||
|
||||
def test_guard_populates_external_identity_from_subject_email(app):
|
||||
from controllers.openapi.auth.data import ExternalIdentity
|
||||
|
||||
router = _make_router(token_type=TokenType.OAUTH_EXTERNAL_SSO)
|
||||
received = {}
|
||||
|
||||
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
|
||||
with (
|
||||
patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
|
||||
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
|
||||
patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()),
|
||||
patch("controllers.openapi.auth.pipeline.reset_auth_ctx"),
|
||||
):
|
||||
identity = _make_identity(
|
||||
token_type=TokenType.OAUTH_EXTERNAL_SSO,
|
||||
subject_email="user@sso.com",
|
||||
subject_issuer="https://idp.example.com",
|
||||
)
|
||||
mock_auth.return_value.authenticate.return_value = identity
|
||||
|
||||
@router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_EXTERNAL_SSO}))
|
||||
def view(*, auth_data):
|
||||
received["data"] = auth_data
|
||||
|
||||
view()
|
||||
|
||||
assert isinstance(received["data"].external_identity, ExternalIdentity)
|
||||
assert received["data"].external_identity.email == "user@sso.com"
|
||||
assert received["data"].external_identity.issuer == "https://idp.example.com"
|
||||
|
||||
|
||||
def test_guard_no_external_identity_when_subject_email_absent(app):
|
||||
router = _make_router()
|
||||
received = {}
|
||||
|
||||
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
|
||||
with (
|
||||
patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
|
||||
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
|
||||
patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()),
|
||||
patch("controllers.openapi.auth.pipeline.reset_auth_ctx"),
|
||||
):
|
||||
mock_auth.return_value.authenticate.return_value = _make_identity(subject_email=None)
|
||||
|
||||
@router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def view(*, auth_data):
|
||||
received["data"] = auth_data
|
||||
|
||||
view()
|
||||
|
||||
assert received["data"].external_identity is None
|
||||
|
||||
183
api/tests/unit_tests/controllers/openapi/auth/test_prepare.py
Normal file
183
api/tests/unit_tests/controllers/openapi/auth/test_prepare.py
Normal file
@ -0,0 +1,183 @@
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.data import AuthData, ExternalIdentity
|
||||
from controllers.openapi.auth.prepare import (
|
||||
load_account,
|
||||
load_app,
|
||||
load_app_access_mode,
|
||||
load_tenant,
|
||||
resolve_external_user,
|
||||
)
|
||||
from libs.oauth_bearer import TokenType
|
||||
|
||||
|
||||
def _make_auth_data(**kwargs) -> AuthData:
|
||||
mock_fields = {k: kwargs.pop(k) for k in ("app", "tenant", "caller") if k in kwargs}
|
||||
data = AuthData(
|
||||
token_type=kwargs.pop("token_type", TokenType.OAUTH_ACCOUNT),
|
||||
token_hash=kwargs.pop("token_hash", "testhash"),
|
||||
scopes=kwargs.pop("scopes", frozenset()),
|
||||
**kwargs,
|
||||
)
|
||||
for k, v in mock_fields.items():
|
||||
setattr(data, k, v)
|
||||
return data
|
||||
|
||||
|
||||
def test_load_app_writes_app_to_data():
|
||||
app = MagicMock()
|
||||
app.status = "normal"
|
||||
app.enable_api = True
|
||||
data = _make_auth_data(path_params={"app_id": "abc"})
|
||||
with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=app):
|
||||
load_app(data)
|
||||
assert data.app is app
|
||||
|
||||
|
||||
def test_load_app_raises_not_found_when_missing():
|
||||
data = _make_auth_data(path_params={"app_id": "missing"})
|
||||
with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=None):
|
||||
with pytest.raises(NotFound):
|
||||
load_app(data)
|
||||
|
||||
|
||||
def test_load_app_raises_not_found_when_not_normal():
|
||||
app = MagicMock()
|
||||
app.status = "archived"
|
||||
data = _make_auth_data(path_params={"app_id": "abc"})
|
||||
with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=app):
|
||||
with pytest.raises(NotFound):
|
||||
load_app(data)
|
||||
|
||||
|
||||
def test_load_app_raises_forbidden_when_api_disabled():
|
||||
app = MagicMock()
|
||||
app.status = "normal"
|
||||
app.enable_api = False
|
||||
data = _make_auth_data(path_params={"app_id": "abc"})
|
||||
with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=app):
|
||||
with pytest.raises(Forbidden):
|
||||
load_app(data)
|
||||
|
||||
|
||||
def test_load_tenant_writes_tenant():
|
||||
app = MagicMock()
|
||||
app.tenant_id = uuid.uuid4()
|
||||
tenant = MagicMock()
|
||||
tenant.status = "normal"
|
||||
data = _make_auth_data(app=app)
|
||||
with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=tenant):
|
||||
load_tenant(data)
|
||||
assert data.tenant is tenant
|
||||
|
||||
|
||||
def test_load_tenant_raises_forbidden_when_archived():
|
||||
from models.account import TenantStatus
|
||||
|
||||
app = MagicMock()
|
||||
app.tenant_id = uuid.uuid4()
|
||||
tenant = MagicMock()
|
||||
tenant.status = TenantStatus.ARCHIVE
|
||||
data = _make_auth_data(app=app)
|
||||
with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=tenant):
|
||||
with pytest.raises(Forbidden):
|
||||
load_tenant(data)
|
||||
|
||||
|
||||
def test_load_tenant_raises_forbidden_when_missing():
|
||||
app = MagicMock()
|
||||
app.tenant_id = uuid.uuid4()
|
||||
data = _make_auth_data(app=app)
|
||||
with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=None):
|
||||
with pytest.raises(Forbidden):
|
||||
load_tenant(data)
|
||||
|
||||
|
||||
def test_load_tenant_raises_500_when_app_not_loaded():
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
data = _make_auth_data()
|
||||
with pytest.raises(InternalServerError):
|
||||
load_tenant(data)
|
||||
|
||||
|
||||
def test_load_account_writes_caller():
|
||||
account = MagicMock()
|
||||
account_id = uuid.uuid4()
|
||||
data = _make_auth_data(account_id=account_id)
|
||||
with patch("controllers.openapi.auth.prepare.AccountService.get_account_by_id", return_value=account):
|
||||
load_account(data)
|
||||
assert data.caller is account
|
||||
assert data.caller_kind == "account"
|
||||
|
||||
|
||||
def test_load_account_sets_current_tenant_when_tenant_present():
|
||||
account = MagicMock()
|
||||
tenant = MagicMock()
|
||||
data = _make_auth_data(account_id=uuid.uuid4(), tenant=tenant)
|
||||
with patch("controllers.openapi.auth.prepare.AccountService.get_account_by_id", return_value=account):
|
||||
load_account(data)
|
||||
assert account.current_tenant is tenant
|
||||
|
||||
|
||||
def test_load_account_raises_unauthorized_when_not_found():
|
||||
data = _make_auth_data(account_id=uuid.uuid4())
|
||||
with patch("controllers.openapi.auth.prepare.AccountService.get_account_by_id", return_value=None):
|
||||
with pytest.raises(Unauthorized):
|
||||
load_account(data)
|
||||
|
||||
|
||||
def test_resolve_external_user_writes_caller():
|
||||
tenant = MagicMock()
|
||||
app = MagicMock()
|
||||
end_user = MagicMock()
|
||||
ext = ExternalIdentity(email="user@sso.com")
|
||||
data = _make_auth_data(tenant=tenant, app=app, external_identity=ext)
|
||||
with patch("controllers.openapi.auth.prepare.EndUserService.get_or_create_end_user_by_type", return_value=end_user):
|
||||
resolve_external_user(data)
|
||||
assert data.caller is end_user
|
||||
assert data.caller_kind == "end_user"
|
||||
|
||||
|
||||
def test_resolve_external_user_raises_unauthorized_when_context_missing():
|
||||
data = _make_auth_data(tenant=None, app=MagicMock(), external_identity=ExternalIdentity(email="u@s.com"))
|
||||
with pytest.raises(Unauthorized):
|
||||
resolve_external_user(data)
|
||||
|
||||
|
||||
def test_load_app_access_mode_writes_mode():
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
|
||||
app = MagicMock()
|
||||
app.id = "app-1"
|
||||
settings = MagicMock()
|
||||
settings.access_mode = "public"
|
||||
data = _make_auth_data(app=app)
|
||||
with patch(
|
||||
"controllers.openapi.auth.prepare.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
|
||||
return_value=settings,
|
||||
):
|
||||
load_app_access_mode(data)
|
||||
assert data.app_access_mode == WebAppAccessMode.PUBLIC
|
||||
|
||||
|
||||
def test_load_app_access_mode_writes_none_when_value_error():
|
||||
app = MagicMock()
|
||||
app.id = "app-1"
|
||||
data = _make_auth_data(app=app)
|
||||
with patch(
|
||||
"controllers.openapi.auth.prepare.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
|
||||
side_effect=ValueError("No data found."),
|
||||
):
|
||||
load_app_access_mode(data)
|
||||
assert data.app_access_mode is None
|
||||
|
||||
|
||||
def test_load_app_access_mode_no_op_when_app_missing():
|
||||
data = _make_auth_data()
|
||||
load_app_access_mode(data)
|
||||
assert data.app_access_mode is None
|
||||
@ -26,7 +26,7 @@ from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.openapi.auth.role_gate import require_workspace_role
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType, reset_auth_ctx, set_auth_ctx
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, reset_auth_ctx, set_auth_ctx
|
||||
from models.account import TenantAccountRole
|
||||
|
||||
# Tokens from `_seed`'s `set_auth_ctx` calls, drained after each test so a
|
||||
@ -55,7 +55,7 @@ def _account_ctx(account_id: uuid.UUID | None = None) -> AuthContext:
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
token_id=uuid.uuid4(),
|
||||
source="oauth_account",
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
expires_at=datetime.now(UTC),
|
||||
token_hash="h1",
|
||||
verified_tenants={},
|
||||
@ -71,7 +71,7 @@ def _sso_ctx() -> AuthContext:
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({Scope.APPS_RUN}),
|
||||
token_id=uuid.uuid4(),
|
||||
source="oauth_external_sso",
|
||||
token_type=TokenType.OAUTH_EXTERNAL_SSO,
|
||||
expires_at=datetime.now(UTC),
|
||||
token_hash="h2",
|
||||
verified_tenants={},
|
||||
|
||||
@ -1,64 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import AppResolver
|
||||
from models import TenantStatus
|
||||
|
||||
|
||||
def _ctx(path_params: dict[str, str] | None) -> Context:
|
||||
return Context(required_scope="apps:run", path_params=path_params or {})
|
||||
|
||||
|
||||
def _app(*, status="normal", enable_api=True):
|
||||
return SimpleNamespace(id="app1", tenant_id="t1", status=status, enable_api=enable_api)
|
||||
|
||||
|
||||
def _tenant(*, status=TenantStatus.NORMAL):
|
||||
return SimpleNamespace(id="t1", status=status)
|
||||
|
||||
|
||||
def test_resolver_rejects_missing_path_param():
|
||||
with pytest.raises(BadRequest):
|
||||
AppResolver()(_ctx({}))
|
||||
|
||||
|
||||
def test_resolver_rejects_empty_path_params():
|
||||
# `Pipeline.guard` always seeds an empty dict when Flask reports no
|
||||
# view args, so a missing `app_id` key surfaces here as BadRequest.
|
||||
with pytest.raises(BadRequest):
|
||||
AppResolver()(_ctx(None))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.db")
|
||||
def test_resolver_404_when_app_missing(db):
|
||||
db.session.get.side_effect = [None]
|
||||
with pytest.raises(NotFound):
|
||||
AppResolver()(_ctx({"app_id": "x"}))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.db")
|
||||
def test_resolver_403_when_disabled(db):
|
||||
db.session.get.side_effect = [_app(enable_api=False)]
|
||||
with pytest.raises(Forbidden) as exc:
|
||||
AppResolver()(_ctx({"app_id": "x"}))
|
||||
assert "service_api_disabled" in str(exc.value.description)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.db")
|
||||
def test_resolver_403_when_tenant_archived(db):
|
||||
db.session.get.side_effect = [_app(), _tenant(status=TenantStatus.ARCHIVE)]
|
||||
with pytest.raises(Forbidden):
|
||||
AppResolver()(_ctx({"app_id": "x"}))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.db")
|
||||
def test_resolver_populates_app_and_tenant(db):
|
||||
db.session.get.side_effect = [_app(), _tenant()]
|
||||
ctx = _ctx({"app_id": "x"})
|
||||
AppResolver()(ctx)
|
||||
assert ctx.app.id == "app1"
|
||||
assert ctx.tenant.id == "t1"
|
||||
@ -1,76 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import AppAuthzCheck
|
||||
from controllers.openapi.auth.strategies import AclStrategy, MembershipStrategy
|
||||
from libs.oauth_bearer import SubjectType
|
||||
|
||||
|
||||
def _ctx(*, subject_type, account_id="acc1"):
|
||||
c = Context(required_scope="apps:run")
|
||||
c.subject_type = subject_type
|
||||
c.subject_email = "alice@example.com"
|
||||
c.account_id = account_id
|
||||
c.app = SimpleNamespace(id="app1")
|
||||
c.tenant = SimpleNamespace(id="t1")
|
||||
return c
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.strategies.EnterpriseService")
|
||||
def test_acl_strategy_private_calls_inner_api(ent):
|
||||
ent.WebAppAuth.get_app_access_mode_by_id.return_value = SimpleNamespace(access_mode="private")
|
||||
ent.WebAppAuth.is_user_allowed_to_access_webapp.return_value = True
|
||||
assert AclStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True
|
||||
ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_called_once_with(
|
||||
user_id="acc1",
|
||||
app_id="app1",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("access_mode", "subject_type", "expected"),
|
||||
[
|
||||
("public", SubjectType.ACCOUNT, True),
|
||||
("public", SubjectType.EXTERNAL_SSO, True),
|
||||
("sso_verified", SubjectType.ACCOUNT, True),
|
||||
("sso_verified", SubjectType.EXTERNAL_SSO, True),
|
||||
("private_all", SubjectType.ACCOUNT, True),
|
||||
("private_all", SubjectType.EXTERNAL_SSO, False),
|
||||
("private", SubjectType.EXTERNAL_SSO, False),
|
||||
],
|
||||
)
|
||||
@patch("controllers.openapi.auth.strategies.EnterpriseService")
|
||||
def test_acl_strategy_subject_mode_matrix(ent, access_mode, subject_type, expected):
|
||||
"""Step 1 matrix: subject vs access-mode compatibility. No inner API call expected."""
|
||||
ent.WebAppAuth.get_app_access_mode_by_id.return_value = SimpleNamespace(access_mode=access_mode)
|
||||
account_id = "acc1" if subject_type == SubjectType.ACCOUNT else None
|
||||
assert AclStrategy().authorize(_ctx(subject_type=subject_type, account_id=account_id)) is expected
|
||||
ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.strategies.TenantService.account_belongs_to_tenant")
|
||||
@patch("controllers.openapi.auth.strategies.db")
|
||||
def test_membership_strategy_uses_join_lookup(db_mock, member):
|
||||
member.return_value = True
|
||||
assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True
|
||||
member.assert_called_once_with(db_mock.session, "acc1", "t1")
|
||||
|
||||
|
||||
def test_membership_strategy_rejects_external_sso():
|
||||
assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.EXTERNAL_SSO, account_id=None)) is False
|
||||
|
||||
|
||||
def test_app_authz_check_raises_when_strategy_denies():
|
||||
deny = SimpleNamespace(authorize=lambda c: False)
|
||||
with pytest.raises(Forbidden) as exc:
|
||||
AppAuthzCheck(lambda: deny)(_ctx(subject_type=SubjectType.ACCOUNT))
|
||||
assert "subject_no_app_access" in str(exc.value.description)
|
||||
|
||||
|
||||
def test_app_authz_check_passes_when_strategy_allows():
|
||||
allow = SimpleNamespace(authorize=lambda c: True)
|
||||
AppAuthzCheck(lambda: allow)(_ctx(subject_type=SubjectType.ACCOUNT))
|
||||
@ -1,83 +0,0 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import BearerCheck
|
||||
from libs.oauth_bearer import (
|
||||
AuthContext,
|
||||
InvalidBearerError,
|
||||
Scope,
|
||||
SubjectType,
|
||||
reset_auth_ctx,
|
||||
try_get_auth_ctx,
|
||||
)
|
||||
|
||||
|
||||
def _ctx(bearer_token: str | None) -> Context:
|
||||
return Context(required_scope="apps:run", bearer_token=bearer_token)
|
||||
|
||||
|
||||
def test_bearer_check_rejects_missing_header():
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context(), pytest.raises(Unauthorized):
|
||||
BearerCheck()(_ctx(None))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.get_authenticator")
|
||||
def test_bearer_check_rejects_unknown_prefix(get_auth):
|
||||
get_auth.return_value.authenticate.side_effect = InvalidBearerError("invalid_bearer")
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context(), pytest.raises(Unauthorized):
|
||||
BearerCheck()(_ctx("xxx_abc"))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.get_authenticator")
|
||||
def test_bearer_check_populates_context_and_publishes_auth_ctx(get_auth):
|
||||
tok_id = uuid.uuid4()
|
||||
authn = AuthContext(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
subject_email="a@x.com",
|
||||
subject_issuer=None,
|
||||
account_id=None,
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
token_id=tok_id,
|
||||
source="oauth-account",
|
||||
expires_at=datetime.now(UTC),
|
||||
token_hash="hash-1",
|
||||
verified_tenants={},
|
||||
)
|
||||
get_auth.return_value.authenticate.return_value = authn
|
||||
|
||||
app = Flask(__name__)
|
||||
ctx = _ctx("dfoa_abc")
|
||||
with app.test_request_context():
|
||||
BearerCheck()(ctx)
|
||||
try:
|
||||
assert ctx.subject_type == SubjectType.ACCOUNT
|
||||
assert ctx.subject_email == "a@x.com"
|
||||
assert ctx.scopes == frozenset({Scope.FULL})
|
||||
assert ctx.source == "oauth-account"
|
||||
assert ctx.token_id == tok_id
|
||||
assert ctx.token_hash == "hash-1"
|
||||
# BearerCheck must also publish the same identity on the
|
||||
# openapi auth ContextVar so the surface gate + downstream
|
||||
# handlers don't see two different identity sources between
|
||||
# the decorator + pipeline paths. The reset token is parked
|
||||
# on `ctx.auth_ctx_reset_token` for `Pipeline.guard` to
|
||||
# consume in its `finally`.
|
||||
published = try_get_auth_ctx()
|
||||
assert published is authn
|
||||
assert published.client_id == "difyctl"
|
||||
assert ctx.auth_ctx_reset_token is not None
|
||||
finally:
|
||||
# In production `Pipeline.guard` resets the ContextVar; in
|
||||
# this isolated step-level test we reset it ourselves so the
|
||||
# value doesn't leak into the next test on the same worker.
|
||||
assert ctx.auth_ctx_reset_token is not None
|
||||
reset_auth_ctx(ctx.auth_ctx_reset_token)
|
||||
@ -1,157 +0,0 @@
|
||||
"""Unit tests for WorkspaceMembershipCheck (Layer 0)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import WorkspaceMembershipCheck
|
||||
from libs.oauth_bearer import SubjectType
|
||||
|
||||
|
||||
def _ctx(*, subject_type, account_id, tenant_id, cached_verified_tenants=None, token_hash=None) -> Context:
|
||||
c = Context(required_scope="apps:read")
|
||||
c.subject_type = subject_type
|
||||
c.account_id = account_id
|
||||
c.tenant = SimpleNamespace(id=tenant_id) if tenant_id else None
|
||||
c.cached_verified_tenants = cached_verified_tenants
|
||||
c.token_hash = token_hash
|
||||
return c
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def step():
|
||||
return WorkspaceMembershipCheck()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_skips_when_enterprise_enabled(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = True
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id=str(uuid.uuid4()),
|
||||
tenant_id=str(uuid.uuid4()),
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
step(ctx) # no raise
|
||||
mock_db.session.execute.assert_not_called()
|
||||
mock_record.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_skips_for_external_sso(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.EXTERNAL_SSO,
|
||||
account_id=None,
|
||||
tenant_id=str(uuid.uuid4()),
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
step(ctx) # no raise
|
||||
mock_db.session.execute.assert_not_called()
|
||||
mock_record.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_uses_cached_ok(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={"t1": True},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
step(ctx)
|
||||
mock_db.session.execute.assert_not_called()
|
||||
mock_record.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_uses_cached_denied(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={"t1": False},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
|
||||
step(ctx)
|
||||
mock_db.session.execute.assert_not_called()
|
||||
mock_record.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_denies_when_no_membership(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
mock_db.session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
|
||||
step(ctx)
|
||||
mock_record.assert_called_once_with("hash-1", "t1", False)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_denies_when_account_inactive(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
mock_db.session.execute.side_effect = [
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")),
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="banned")),
|
||||
]
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
|
||||
step(ctx)
|
||||
mock_record.assert_called_once_with("hash-1", "t1", False)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_allows_active_member(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
mock_db.session.execute.side_effect = [
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")),
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="active")),
|
||||
]
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
step(ctx) # no raise
|
||||
mock_record.assert_called_once_with("hash-1", "t1", True)
|
||||
@ -1,77 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import CallerMount
|
||||
from controllers.openapi.auth.strategies import AccountMounter, EndUserMounter
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from libs.oauth_bearer import SubjectType
|
||||
|
||||
|
||||
def _ctx(*, subject_type, account_id=None, subject_email=None):
|
||||
c = Context(required_scope="apps:run")
|
||||
c.subject_type = subject_type
|
||||
c.account_id = account_id
|
||||
c.subject_email = subject_email
|
||||
c.app = SimpleNamespace(id="app1")
|
||||
c.tenant = SimpleNamespace(id="t1")
|
||||
return c
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.strategies._login_as")
|
||||
@patch("controllers.openapi.auth.strategies.db")
|
||||
def test_account_mounter(db, login):
|
||||
account = SimpleNamespace()
|
||||
db.session.get.return_value = account
|
||||
ctx = _ctx(subject_type=SubjectType.ACCOUNT, account_id="acc1")
|
||||
AccountMounter().mount(ctx)
|
||||
assert ctx.caller is account
|
||||
assert ctx.caller.current_tenant is ctx.tenant
|
||||
assert ctx.caller_kind == "account"
|
||||
login.assert_called_once_with(account)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.strategies._login_as")
|
||||
@patch("controllers.openapi.auth.strategies.EndUserService")
|
||||
def test_end_user_mounter(svc, login):
|
||||
eu = SimpleNamespace()
|
||||
svc.get_or_create_end_user_by_type.return_value = eu
|
||||
ctx = _ctx(subject_type=SubjectType.EXTERNAL_SSO, subject_email="a@x.com")
|
||||
EndUserMounter().mount(ctx)
|
||||
svc.get_or_create_end_user_by_type.assert_called_once_with(
|
||||
InvokeFrom.OPENAPI,
|
||||
tenant_id="t1",
|
||||
app_id="app1",
|
||||
user_id="a@x.com",
|
||||
)
|
||||
assert ctx.caller is eu
|
||||
assert ctx.caller_kind == "end_user"
|
||||
|
||||
|
||||
def test_caller_mount_dispatches_by_subject_type():
|
||||
seen = {}
|
||||
|
||||
class Fake:
|
||||
def __init__(self, st, tag):
|
||||
self._st, self._tag = st, tag
|
||||
|
||||
def applies_to(self, st):
|
||||
return st == self._st
|
||||
|
||||
def mount(self, ctx):
|
||||
seen["who"] = self._tag
|
||||
|
||||
cm = CallerMount(
|
||||
Fake(SubjectType.ACCOUNT, "acct"),
|
||||
Fake(SubjectType.EXTERNAL_SSO, "sso"),
|
||||
)
|
||||
cm(_ctx(subject_type=SubjectType.EXTERNAL_SSO))
|
||||
assert seen == {"who": "sso"}
|
||||
|
||||
|
||||
def test_caller_mount_raises_when_none_applies():
|
||||
with pytest.raises(Unauthorized):
|
||||
CallerMount()(_ctx(subject_type=SubjectType.ACCOUNT))
|
||||
@ -1,25 +0,0 @@
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import ScopeCheck
|
||||
|
||||
|
||||
def _ctx(scopes, required):
|
||||
c = Context(required_scope=required)
|
||||
c.scopes = frozenset(scopes)
|
||||
return c
|
||||
|
||||
|
||||
def test_scope_check_passes_on_full():
|
||||
ScopeCheck()(_ctx({"full"}, "apps:run"))
|
||||
|
||||
|
||||
def test_scope_check_passes_on_explicit_match():
|
||||
ScopeCheck()(_ctx({"apps:run"}, "apps:run"))
|
||||
|
||||
|
||||
def test_scope_check_rejects_when_missing():
|
||||
with pytest.raises(Forbidden) as exc:
|
||||
ScopeCheck()(_ctx({"apps:read"}, "apps:run"))
|
||||
assert "insufficient_scope" in str(exc.value.description)
|
||||
@ -1,239 +0,0 @@
|
||||
"""Surface gate tests.
|
||||
|
||||
The gate has two attachment forms — decorator (`accept_subjects`) and
|
||||
pipeline step (`SurfaceCheck`) — and both must:
|
||||
- 403 on mismatched subject type with a canonical-path hint
|
||||
- emit `openapi.wrong_surface_denied` once with the right payload
|
||||
- pass-through on match
|
||||
- raise RuntimeError (not 403) if the auth ContextVar is unset — that's
|
||||
a wiring bug, not a user-driven failure
|
||||
|
||||
Identity is published via `libs.oauth_bearer.set_auth_ctx` / read with
|
||||
`try_get_auth_ctx`. Tests wrap the publish in a `_publish_auth_ctx`
|
||||
context manager so the ContextVar resets even when an assertion fails;
|
||||
that keeps state from leaking into the next test on the same worker.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import SurfaceCheck
|
||||
from controllers.openapi.auth.surface_gate import _coerce_subject_type, accept_subjects, check_surface
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType, reset_auth_ctx, set_auth_ctx
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _publish_auth_ctx(ctx: AuthContext) -> Iterator[None]:
|
||||
token = set_auth_ctx(ctx)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
reset_auth_ctx(token)
|
||||
|
||||
|
||||
def _account_ctx() -> AuthContext:
|
||||
return AuthContext(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
subject_email="user@example.com",
|
||||
subject_issuer="dify:account",
|
||||
account_id=uuid.uuid4(),
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
token_id=uuid.uuid4(),
|
||||
source="oauth_account",
|
||||
expires_at=datetime.now(UTC),
|
||||
token_hash="h1",
|
||||
verified_tenants={},
|
||||
)
|
||||
|
||||
|
||||
def _sso_ctx() -> AuthContext:
|
||||
return AuthContext(
|
||||
subject_type=SubjectType.EXTERNAL_SSO,
|
||||
subject_email="sso@partner.com",
|
||||
subject_issuer="https://idp.partner.com",
|
||||
account_id=None,
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({Scope.APPS_RUN, Scope.APPS_READ_PERMITTED_EXTERNAL}),
|
||||
token_id=uuid.uuid4(),
|
||||
source="oauth_external_sso",
|
||||
expires_at=datetime.now(UTC),
|
||||
token_hash="h2",
|
||||
verified_tenants={},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_surface — shared core
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_check_surface_passes_when_subject_in_accepted():
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps"), _publish_auth_ctx(_account_ctx()):
|
||||
check_surface(frozenset({SubjectType.ACCOUNT})) # no raise
|
||||
|
||||
|
||||
def test_check_surface_rejects_on_wrong_subject_and_emits_audit():
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/permitted-external-apps"), _publish_auth_ctx(_account_ctx()):
|
||||
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
|
||||
with pytest.raises(Forbidden) as exc:
|
||||
check_surface(frozenset({SubjectType.EXTERNAL_SSO}))
|
||||
assert "wrong_surface" in exc.value.description
|
||||
# canonical-path hint should point at the caller's surface,
|
||||
# not the surface they were rejected from
|
||||
assert "/openapi/v1/apps" in exc.value.description
|
||||
emit.assert_called_once()
|
||||
kwargs = emit.call_args.kwargs
|
||||
assert kwargs["subject_type"] == SubjectType.ACCOUNT.value
|
||||
assert kwargs["attempted_path"] == "/openapi/v1/permitted-external-apps"
|
||||
assert kwargs["client_id"] == "difyctl"
|
||||
assert kwargs["token_id"] is not None
|
||||
|
||||
|
||||
def test_check_surface_rejects_sso_on_account_surface():
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps"), _publish_auth_ctx(_sso_ctx()):
|
||||
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
|
||||
with pytest.raises(Forbidden):
|
||||
check_surface(frozenset({SubjectType.ACCOUNT}))
|
||||
kwargs = emit.call_args.kwargs
|
||||
assert kwargs["subject_type"] == SubjectType.EXTERNAL_SSO.value
|
||||
|
||||
|
||||
def test_check_surface_runtime_error_when_auth_ctx_missing():
|
||||
"""Missing auth ContextVar means the bearer layer didn't run — wiring
|
||||
bug, not a user-driven failure. Surface as RuntimeError (loud) so a
|
||||
future refactor doesn't accidentally let a route skip authentication
|
||||
and return a 403 that looks identical to a legitimate wrong-surface
|
||||
deny.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps"):
|
||||
with pytest.raises(RuntimeError):
|
||||
check_surface(frozenset({SubjectType.ACCOUNT}))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# @accept_subjects — decorator form
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
|
||||
@app.route("/account-only")
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
def _account_only():
|
||||
return "ok"
|
||||
|
||||
@app.route("/external-only")
|
||||
@accept_subjects(SubjectType.EXTERNAL_SSO)
|
||||
def _external_only():
|
||||
return "ok"
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def test_accept_subjects_decorator_passes_on_match():
|
||||
app = _make_app()
|
||||
with app.test_request_context("/account-only"), _publish_auth_ctx(_account_ctx()):
|
||||
# Re-route through the decorated function by reaching for view_function
|
||||
view = app.view_functions["_account_only"]
|
||||
assert view() == "ok"
|
||||
|
||||
|
||||
def test_accept_subjects_decorator_403_on_miss():
|
||||
app = _make_app()
|
||||
with app.test_request_context("/external-only"), _publish_auth_ctx(_account_ctx()):
|
||||
view = app.view_functions["_external_only"]
|
||||
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface"):
|
||||
with pytest.raises(Forbidden):
|
||||
view()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SurfaceCheck — pipeline step form
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _pipeline_ctx() -> Context:
|
||||
# SurfaceCheck reads ``request.path`` from Flask's global request — set up
|
||||
# via ``app.test_request_context`` in the calling tests — not from Context.
|
||||
return Context(required_scope=Scope.APPS_RUN)
|
||||
|
||||
|
||||
def test_surface_check_passes_on_match():
|
||||
step = SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT}))
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps/x/run"), _publish_auth_ctx(_account_ctx()):
|
||||
step(_pipeline_ctx()) # no raise
|
||||
|
||||
|
||||
def test_surface_check_rejects_on_miss_and_emits_audit():
|
||||
step = SurfaceCheck(accepted=frozenset({SubjectType.EXTERNAL_SSO}))
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps/x/run"), _publish_auth_ctx(_account_ctx()):
|
||||
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
|
||||
with pytest.raises(Forbidden):
|
||||
step(_pipeline_ctx())
|
||||
emit.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _coerce_subject_type — normalises whatever sat on ctx.subject_type
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# The gate reads `ctx.subject_type` via `getattr(..., None)`, so the value
|
||||
# could be a real enum (happy path), a raw string (e.g. rehydrated from a
|
||||
# dict-shaped context), `None` (attribute missing), or something unexpected
|
||||
# from a buggy upstream. The coercer must collapse all of that to
|
||||
# `SubjectType | None` so `check_surface` can do a clean set-membership
|
||||
# check and emit a clean audit payload.
|
||||
|
||||
|
||||
def test_coerce_subject_type_returns_none_for_none():
|
||||
assert _coerce_subject_type(None) is None
|
||||
|
||||
|
||||
def test_coerce_subject_type_returns_enum_instance_unchanged():
|
||||
# Identity matters: we don't want to round-trip through the string
|
||||
# constructor for an already-valid enum.
|
||||
assert _coerce_subject_type(SubjectType.ACCOUNT) is SubjectType.ACCOUNT
|
||||
assert _coerce_subject_type(SubjectType.EXTERNAL_SSO) is SubjectType.EXTERNAL_SSO
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("raw", "expected"),
|
||||
[
|
||||
("account", SubjectType.ACCOUNT),
|
||||
("external_sso", SubjectType.EXTERNAL_SSO),
|
||||
],
|
||||
)
|
||||
def test_coerce_subject_type_parses_known_strings(raw: str, expected: SubjectType):
|
||||
assert _coerce_subject_type(raw) is expected
|
||||
|
||||
|
||||
def test_coerce_subject_type_raises_on_unknown_string():
|
||||
# Unknown strings reach `SubjectType(raw)` which raises ValueError.
|
||||
# We surface that loudly rather than silently returning None, because
|
||||
# a string that *looks* like a subject type but isn't is almost
|
||||
# certainly an upstream bug worth catching.
|
||||
with pytest.raises(ValueError):
|
||||
_coerce_subject_type("not_a_subject")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("raw", [123, 1.5, b"account", object(), ["account"], {"account"}])
|
||||
def test_coerce_subject_type_returns_none_for_non_string_non_enum(raw: object):
|
||||
assert _coerce_subject_type(raw) is None
|
||||
142
api/tests/unit_tests/controllers/openapi/auth/test_verify.py
Normal file
142
api/tests/unit_tests/controllers/openapi/auth/test_verify.py
Normal file
@ -0,0 +1,142 @@
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.auth.verify import (
|
||||
check_acl,
|
||||
check_app_access,
|
||||
check_membership,
|
||||
check_private_app_permission,
|
||||
check_scope,
|
||||
)
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models.account import Tenant
|
||||
from models.model import App
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
|
||||
|
||||
def _data(**kwargs) -> AuthData:
|
||||
defaults: dict = {"token_type": TokenType.OAUTH_ACCOUNT, "token_hash": "hash", "scopes": frozenset({Scope.FULL})}
|
||||
defaults.update(kwargs)
|
||||
return AuthData(**defaults)
|
||||
|
||||
|
||||
def test_check_scope_passes_when_required_is_none():
|
||||
check_scope(_data(required_scope=None))
|
||||
|
||||
|
||||
def test_check_scope_passes_when_full_in_scopes():
|
||||
check_scope(_data(required_scope=Scope.APPS_RUN, scopes=frozenset({Scope.FULL})))
|
||||
|
||||
|
||||
def test_check_scope_passes_when_exact_scope_present():
|
||||
check_scope(_data(required_scope=Scope.APPS_RUN, scopes=frozenset({Scope.APPS_RUN})))
|
||||
|
||||
|
||||
def test_check_scope_raises_forbidden_when_scope_missing():
|
||||
with pytest.raises(Forbidden, match="insufficient_scope"):
|
||||
check_scope(_data(required_scope=Scope.APPS_RUN, scopes=frozenset({Scope.APPS_READ})))
|
||||
|
||||
|
||||
def test_check_membership_raises_unauthorized_when_tenant_none():
|
||||
with pytest.raises(Unauthorized):
|
||||
check_membership(_data(tenant=None))
|
||||
|
||||
|
||||
def test_check_membership_calls_check_workspace_membership():
|
||||
tenant = MagicMock(spec=Tenant)
|
||||
tenant.id = "tenant-1"
|
||||
data = _data(
|
||||
account_id=uuid.uuid4(),
|
||||
token_hash="myhash",
|
||||
tenants={"tenant-1": True},
|
||||
tenant=tenant,
|
||||
)
|
||||
with patch("controllers.openapi.auth.verify.check_workspace_membership") as mock_cwm:
|
||||
check_membership(data)
|
||||
mock_cwm.assert_called_once_with(
|
||||
account_id=data.account_id,
|
||||
tenant_id="tenant-1",
|
||||
token_hash="myhash",
|
||||
membership_cache=data.tenants,
|
||||
)
|
||||
|
||||
|
||||
def test_check_app_access_passes_when_tenant_none():
|
||||
check_app_access(_data(tenant=None))
|
||||
|
||||
|
||||
def test_check_app_access_passes_when_member():
|
||||
tenant = MagicMock(spec=Tenant)
|
||||
tenant.id = "t1"
|
||||
data = _data(account_id=uuid.uuid4(), tenant=tenant)
|
||||
with patch("controllers.openapi.auth.verify.TenantService.account_belongs_to_tenant", return_value=True):
|
||||
check_app_access(data)
|
||||
|
||||
|
||||
def test_check_app_access_raises_when_not_member():
|
||||
tenant = MagicMock(spec=Tenant)
|
||||
tenant.id = "t1"
|
||||
data = _data(account_id=uuid.uuid4(), tenant=tenant)
|
||||
with patch("controllers.openapi.auth.verify.TenantService.account_belongs_to_tenant", return_value=False):
|
||||
with pytest.raises(Forbidden, match="subject_no_app_access"):
|
||||
check_app_access(data)
|
||||
|
||||
|
||||
def test_check_acl_raises_when_app_or_mode_missing():
|
||||
with pytest.raises(Forbidden):
|
||||
check_acl(_data(app=None, app_access_mode=None))
|
||||
|
||||
|
||||
def test_check_acl_account_allowed_for_public():
|
||||
app = MagicMock(spec=App)
|
||||
data = _data(token_type=TokenType.OAUTH_ACCOUNT, app=app, app_access_mode=WebAppAccessMode.PUBLIC)
|
||||
check_acl(data)
|
||||
|
||||
|
||||
def test_check_acl_external_sso_blocked_for_private():
|
||||
app = MagicMock(spec=App)
|
||||
data = _data(
|
||||
token_type=TokenType.OAUTH_EXTERNAL_SSO,
|
||||
app=app,
|
||||
app_access_mode=WebAppAccessMode.PRIVATE,
|
||||
)
|
||||
with pytest.raises(Forbidden, match="subject_not_allowed_for_access_mode"):
|
||||
check_acl(data)
|
||||
|
||||
|
||||
def test_check_acl_external_sso_allowed_for_sso_verified():
|
||||
app = MagicMock(spec=App)
|
||||
data = _data(
|
||||
token_type=TokenType.OAUTH_EXTERNAL_SSO,
|
||||
app=app,
|
||||
app_access_mode=WebAppAccessMode.SSO_VERIFIED,
|
||||
)
|
||||
check_acl(data)
|
||||
|
||||
|
||||
def test_check_private_app_permission_raises_when_app_none():
|
||||
with pytest.raises(Forbidden):
|
||||
check_private_app_permission(_data(app=None))
|
||||
|
||||
|
||||
def test_check_private_app_permission_raises_when_user_not_allowed():
|
||||
app = MagicMock(spec=App)
|
||||
app.id = "app-1"
|
||||
data = _data(account_id=uuid.uuid4(), app=app)
|
||||
target = "controllers.openapi.auth.verify.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp"
|
||||
with patch(target, return_value=False):
|
||||
with pytest.raises(Forbidden, match="user_not_allowed_for_private_app"):
|
||||
check_private_app_permission(data)
|
||||
|
||||
|
||||
def test_check_private_app_permission_passes_when_allowed():
|
||||
app = MagicMock(spec=App)
|
||||
app.id = "app-1"
|
||||
data = _data(account_id=uuid.uuid4(), app=app)
|
||||
target = "controllers.openapi.auth.verify.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp"
|
||||
with patch(target, return_value=True):
|
||||
check_private_app_permission(data)
|
||||
@ -1,20 +1,36 @@
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.auth.pipeline import PipelineRouter
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
|
||||
|
||||
def _stub_execute(self, args, kwargs, view, *, scope=None, allowed_token_types=None, edition=None):
|
||||
"""Bypass all auth logic; inject minimal AuthData and call the view directly."""
|
||||
kwargs["auth_data"] = AuthData(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
account_id=uuid.uuid4(),
|
||||
token_hash="test",
|
||||
token_id=uuid.uuid4(),
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
required_scope=scope,
|
||||
)
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bypass_pipeline(monkeypatch):
|
||||
"""Stub Pipeline.run so endpoint decoration does not invoke real auth.
|
||||
"""Stub PipelineRouter._execute so endpoints skip real auth at request time.
|
||||
|
||||
Module-level @OAUTH_BEARER_PIPELINE.guard(...) captures the real
|
||||
pipeline at import time; mocking the module attribute does not undo
|
||||
that. Patching Pipeline.run on the class is the bypass that actually
|
||||
works.
|
||||
Module-level @auth_router.guard(...) captures the real router at import
|
||||
time — patching guard itself does nothing. Patching _execute on the class
|
||||
is the seam that fires at request time.
|
||||
"""
|
||||
monkeypatch.setattr(Pipeline, "run", lambda self, ctx: None)
|
||||
monkeypatch.setattr(PipelineRouter, "_execute", _stub_execute)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -86,7 +86,7 @@ def test_subject_match_for_account_filters_by_account_id():
|
||||
"""Account subject scopes queries via account_id."""
|
||||
import uuid as _uuid
|
||||
|
||||
from libs.oauth_bearer import AuthContext, SubjectType
|
||||
from libs.oauth_bearer import AuthContext, SubjectType, TokenType
|
||||
from services.oauth_device_flow import subject_match_clauses
|
||||
|
||||
aid = _uuid.uuid4()
|
||||
@ -98,7 +98,7 @@ def test_subject_match_for_account_filters_by_account_id():
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({"full"}),
|
||||
token_id=_uuid.uuid4(),
|
||||
source="oauth_account",
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
expires_at=None,
|
||||
token_hash="h1",
|
||||
verified_tenants={},
|
||||
@ -116,7 +116,7 @@ def test_subject_match_for_external_sso_filters_by_email_and_issuer():
|
||||
"""
|
||||
import uuid as _uuid
|
||||
|
||||
from libs.oauth_bearer import AuthContext, SubjectType
|
||||
from libs.oauth_bearer import AuthContext, SubjectType, TokenType
|
||||
from services.oauth_device_flow import subject_match_clauses
|
||||
|
||||
ctx = AuthContext(
|
||||
@ -127,7 +127,7 @@ def test_subject_match_for_external_sso_filters_by_email_and_issuer():
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({"apps:run"}),
|
||||
token_id=_uuid.uuid4(),
|
||||
source="oauth_external_sso",
|
||||
token_type=TokenType.OAUTH_EXTERNAL_SSO,
|
||||
expires_at=None,
|
||||
token_hash="h1",
|
||||
verified_tenants={},
|
||||
|
||||
@ -57,7 +57,11 @@ def test_stop_task_endpoint_registered(openapi_app):
|
||||
|
||||
|
||||
def test_stop_task_calls_queue_manager_and_graph_engine(app, bypass_pipeline, monkeypatch):
|
||||
import uuid
|
||||
|
||||
from controllers.openapi.app_run import AppRunTaskStopApi
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
|
||||
queue_mock = Mock()
|
||||
graph_mock = Mock()
|
||||
@ -69,15 +73,23 @@ def test_stop_task_calls_queue_manager_and_graph_engine(app, bypass_pipeline, mo
|
||||
monkeypatch.setattr(run_module, "GraphEngineManager", graph_mock)
|
||||
monkeypatch.setattr(run_module, "redis_client", object())
|
||||
|
||||
auth_data = AuthData.model_construct(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
account_id=uuid.uuid4(),
|
||||
token_hash="test",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
app=SimpleNamespace(id="app-1", tenant_id="t-1"),
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
)
|
||||
|
||||
api = AppRunTaskStopApi()
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/tasks/task-1/stop", method="POST"):
|
||||
result = api.post.__wrapped__(
|
||||
api,
|
||||
app_id="app-1",
|
||||
task_id="task-1",
|
||||
app_model=SimpleNamespace(id="app-1", tenant_id="t-1"),
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
auth_data=auth_data,
|
||||
)
|
||||
|
||||
queue_mock.set_stop_flag_no_user_check.assert_called_once_with("task-1")
|
||||
|
||||
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
@ -11,9 +12,23 @@ from unittest.mock import Mock
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models.human_input import RecipientType
|
||||
|
||||
|
||||
def _make_auth_data(app_model, caller, caller_kind):
|
||||
return AuthData.model_construct(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
account_id=uuid.uuid4(),
|
||||
token_hash="test",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
app=app_model,
|
||||
caller=caller,
|
||||
caller_kind=caller_kind,
|
||||
)
|
||||
|
||||
|
||||
class TestOpenApiHumanInputFormGet:
|
||||
def test_get_success(self, app, bypass_pipeline, monkeypatch):
|
||||
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
|
||||
@ -43,15 +58,14 @@ class TestOpenApiHumanInputFormGet:
|
||||
|
||||
api = OpenApiWorkflowHumanInputFormApi()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
caller = SimpleNamespace(id="acct-1")
|
||||
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"):
|
||||
resp = api.get.__wrapped__(
|
||||
api,
|
||||
app_id="app-1",
|
||||
form_token="tok-1",
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
)
|
||||
|
||||
payload = json.loads(resp.get_data(as_text=True))
|
||||
@ -71,6 +85,7 @@ class TestOpenApiHumanInputFormGet:
|
||||
|
||||
api = OpenApiWorkflowHumanInputFormApi()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
caller = SimpleNamespace(id="acct-1")
|
||||
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/bad"):
|
||||
with pytest.raises(NotFound):
|
||||
@ -78,9 +93,7 @@ class TestOpenApiHumanInputFormGet:
|
||||
api,
|
||||
app_id="app-1",
|
||||
form_token="bad",
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
)
|
||||
|
||||
def test_get_form_wrong_app(self, app, bypass_pipeline, monkeypatch):
|
||||
@ -97,6 +110,7 @@ class TestOpenApiHumanInputFormGet:
|
||||
|
||||
api = OpenApiWorkflowHumanInputFormApi()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
caller = SimpleNamespace(id="acct-1")
|
||||
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"):
|
||||
with pytest.raises(NotFound):
|
||||
@ -104,9 +118,7 @@ class TestOpenApiHumanInputFormGet:
|
||||
api,
|
||||
app_id="app-1",
|
||||
form_token="tok-1",
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
)
|
||||
|
||||
def test_get_form_wrong_surface(self, app, bypass_pipeline, monkeypatch):
|
||||
@ -126,6 +138,7 @@ class TestOpenApiHumanInputFormGet:
|
||||
|
||||
api = OpenApiWorkflowHumanInputFormApi()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
caller = SimpleNamespace(id="acct-1")
|
||||
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"):
|
||||
with pytest.raises(NotFound):
|
||||
@ -133,9 +146,7 @@ class TestOpenApiHumanInputFormGet:
|
||||
api,
|
||||
app_id="app-1",
|
||||
form_token="tok-1",
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
)
|
||||
|
||||
|
||||
@ -172,9 +183,7 @@ class TestOpenApiHumanInputFormPost:
|
||||
api,
|
||||
app_id="app-1",
|
||||
form_token="tok-1",
|
||||
app_model=app_model,
|
||||
caller=caller,
|
||||
caller_kind="account",
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
)
|
||||
|
||||
service_mock.submit_form_by_token.assert_called_once_with(
|
||||
@ -211,9 +220,7 @@ class TestOpenApiHumanInputFormPost:
|
||||
api,
|
||||
app_id="app-1",
|
||||
form_token="tok-1",
|
||||
app_model=app_model,
|
||||
caller=caller,
|
||||
caller_kind="end_user",
|
||||
auth_data=_make_auth_data(app_model, caller, "end_user"),
|
||||
)
|
||||
|
||||
service_mock.submit_form_by_token.assert_called_once_with(
|
||||
|
||||
@ -3,15 +3,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
|
||||
def _make_auth_data(app_model, caller, caller_kind):
|
||||
return AuthData.model_construct(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
account_id=uuid.uuid4(),
|
||||
token_hash="test",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
app=app_model,
|
||||
caller=caller,
|
||||
caller_kind=caller_kind,
|
||||
)
|
||||
|
||||
|
||||
def _make_workflow_run(
|
||||
*,
|
||||
app_id="app-1",
|
||||
@ -50,6 +65,7 @@ class TestOpenApiWorkflowEventsApi:
|
||||
from models.model import AppMode
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
|
||||
caller = SimpleNamespace(id="acct-1")
|
||||
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
|
||||
with pytest.raises(NotFound):
|
||||
@ -57,9 +73,7 @@ class TestOpenApiWorkflowEventsApi:
|
||||
api,
|
||||
app_id="app-1",
|
||||
task_id="wf-run-1",
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
)
|
||||
|
||||
def test_not_found_when_run_belongs_to_different_app(self, app, bypass_pipeline, monkeypatch):
|
||||
@ -77,6 +91,7 @@ class TestOpenApiWorkflowEventsApi:
|
||||
from models.model import AppMode
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
|
||||
caller = SimpleNamespace(id="acct-1")
|
||||
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
|
||||
with pytest.raises(NotFound):
|
||||
@ -84,9 +99,7 @@ class TestOpenApiWorkflowEventsApi:
|
||||
api,
|
||||
app_id="app-1",
|
||||
task_id="wf-run-1",
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
)
|
||||
|
||||
def test_account_caller_checks_created_by_account(self, app, bypass_pipeline, monkeypatch):
|
||||
@ -115,6 +128,7 @@ class TestOpenApiWorkflowEventsApi:
|
||||
from models.model import AppMode
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
|
||||
caller = SimpleNamespace(id="acct-1")
|
||||
|
||||
api = self._get_api()
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
|
||||
@ -123,9 +137,7 @@ class TestOpenApiWorkflowEventsApi:
|
||||
api,
|
||||
app_id="app-1",
|
||||
task_id="wf-run-1",
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
)
|
||||
assert resp.mimetype == "text/event-stream"
|
||||
|
||||
@ -143,6 +155,7 @@ class TestOpenApiWorkflowEventsApi:
|
||||
from models.model import AppMode
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
|
||||
caller = SimpleNamespace(id="acct-1")
|
||||
|
||||
api = self._get_api()
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
|
||||
@ -151,9 +164,7 @@ class TestOpenApiWorkflowEventsApi:
|
||||
api,
|
||||
app_id="app-1",
|
||||
task_id="wf-run-1",
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
)
|
||||
|
||||
def test_end_user_caller_checks_created_by_end_user(self, app, bypass_pipeline, monkeypatch):
|
||||
@ -179,6 +190,7 @@ class TestOpenApiWorkflowEventsApi:
|
||||
from models.model import AppMode
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
|
||||
caller = SimpleNamespace(id="eu-1")
|
||||
|
||||
api = self._get_api()
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
|
||||
@ -186,9 +198,7 @@ class TestOpenApiWorkflowEventsApi:
|
||||
api,
|
||||
app_id="app-1",
|
||||
task_id="wf-run-1",
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="eu-1"),
|
||||
caller_kind="end_user",
|
||||
auth_data=_make_auth_data(app_model, caller, "end_user"),
|
||||
)
|
||||
assert resp.mimetype == "text/event-stream"
|
||||
|
||||
@ -222,6 +232,7 @@ class TestOpenApiWorkflowEventsApi:
|
||||
from models.model import AppMode
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
|
||||
caller = SimpleNamespace(id="acct-1")
|
||||
|
||||
api = self._get_api()
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
|
||||
@ -229,9 +240,7 @@ class TestOpenApiWorkflowEventsApi:
|
||||
api,
|
||||
app_id="app-1",
|
||||
task_id="wf-run-1",
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
)
|
||||
assert resp.mimetype == "text/event-stream"
|
||||
chunks = list(resp.response)
|
||||
|
||||
@ -38,7 +38,7 @@ from controllers.openapi.workspaces import (
|
||||
WorkspaceMembersApi,
|
||||
WorkspaceSwitchApi,
|
||||
)
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType, reset_auth_ctx, set_auth_ctx
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, reset_auth_ctx, set_auth_ctx
|
||||
from models.account import AccountStatus, TenantAccountRole
|
||||
from services.errors.account import (
|
||||
AccountAlreadyInTenantError,
|
||||
@ -97,13 +97,25 @@ def _auth_ctx(account_id: uuid.UUID | None = None) -> AuthContext:
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
token_id=uuid.uuid4(),
|
||||
source="oauth_account",
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
expires_at=datetime.now(UTC),
|
||||
token_hash="h",
|
||||
verified_tenants={},
|
||||
)
|
||||
|
||||
|
||||
def _auth_data(account_id: uuid.UUID) -> AuthData:
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
|
||||
return AuthData(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
account_id=account_id,
|
||||
token_hash="testhash",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
)
|
||||
|
||||
|
||||
def _account(account_id: str = "acct-1", email: str = "u@example.com") -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=account_id,
|
||||
@ -256,7 +268,7 @@ def test_switch_returns_workspace_detail_with_current_true(app, bypass_pipeline,
|
||||
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/switch", method="POST"):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
body, status = api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
|
||||
body, status = api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
assert status == 200
|
||||
assert body["id"] == ws_id
|
||||
@ -284,7 +296,7 @@ def test_switch_404s_when_service_raises_account_not_link_tenant(app, bypass_pip
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/switch", method="POST"):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(NotFound):
|
||||
api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
|
||||
api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -318,7 +330,7 @@ def test_members_list_returns_normalized_rows(app, bypass_pipeline, monkeypatch)
|
||||
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members"):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
body, status = api.get.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
|
||||
body, status = api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
assert status == 200
|
||||
assert body["page"] == 1
|
||||
@ -360,7 +372,7 @@ def test_members_list_paginates_with_query_params(app, bypass_pipeline, monkeypa
|
||||
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members?page=2&limit=2"):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
body, status = api.get.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
|
||||
body, status = api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
assert status == 200
|
||||
assert body["page"] == 2
|
||||
@ -383,7 +395,7 @@ def test_members_list_rejects_unknown_query_param(app, bypass_pipeline, monkeypa
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members?pg=2"):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(BadRequest):
|
||||
api.get.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
|
||||
api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -421,7 +433,7 @@ def test_invite_happy_path_returns_invite_url_and_member_id(app, bypass_pipeline
|
||||
content_type="application/json",
|
||||
):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
body, status = api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
|
||||
body, status = api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
assert status == 201
|
||||
assert body["result"] == "success"
|
||||
@ -506,7 +518,7 @@ def test_invite_blocked_by_saas_members_cap(app, bypass_pipeline, monkeypatch):
|
||||
with _invite_request(app, ws_id, acct_id):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(Forbidden) as exc_info:
|
||||
api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
|
||||
api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
body = exc_info.value.response.json
|
||||
assert body["code"] == "members.limit_exceeded"
|
||||
@ -552,7 +564,7 @@ def test_invite_blocked_by_ee_workspace_members_license(app, bypass_pipeline, mo
|
||||
with _invite_request(app, ws_id, acct_id):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(Forbidden) as exc_info:
|
||||
api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
|
||||
api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
body = exc_info.value.response.json
|
||||
assert body["code"] == "workspace_members.license_exceeded"
|
||||
@ -591,7 +603,7 @@ def test_invite_ce_passes_when_both_caps_disabled(app, bypass_pipeline, monkeypa
|
||||
|
||||
with _invite_request(app, ws_id, acct_id):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
body, status = api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
|
||||
body, status = api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
assert status == 201
|
||||
assert body["email"] == "new@example.com"
|
||||
@ -620,7 +632,7 @@ def test_invite_400_when_already_in_tenant(app, bypass_pipeline, monkeypatch):
|
||||
):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(BadRequest):
|
||||
api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
|
||||
api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -653,10 +665,8 @@ def test_delete_member_happy_path(app, bypass_pipeline, monkeypatch):
|
||||
method="DELETE",
|
||||
):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
body, status = api.delete.__wrapped__.__wrapped__.__wrapped__(
|
||||
api,
|
||||
workspace_id=ws_id,
|
||||
member_id=member_id,
|
||||
body, status = api.delete.__wrapped__.__wrapped__(
|
||||
api, workspace_id=ws_id, member_id=member_id, auth_data=_auth_data(acct_id)
|
||||
)
|
||||
|
||||
assert status == 200
|
||||
@ -697,10 +707,11 @@ def test_delete_member_exception_mapping(app, bypass_pipeline, monkeypatch, exc,
|
||||
):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(expected):
|
||||
api.delete.__wrapped__.__wrapped__.__wrapped__(
|
||||
api.delete.__wrapped__.__wrapped__(
|
||||
api,
|
||||
workspace_id=ws_id,
|
||||
member_id=member_id,
|
||||
auth_data=_auth_data(acct_id),
|
||||
)
|
||||
|
||||
|
||||
@ -723,10 +734,11 @@ def test_delete_member_404_when_member_missing(app, bypass_pipeline, monkeypatch
|
||||
):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(NotFound):
|
||||
api.delete.__wrapped__.__wrapped__.__wrapped__(
|
||||
api.delete.__wrapped__.__wrapped__(
|
||||
api,
|
||||
workspace_id=ws_id,
|
||||
member_id=member_id,
|
||||
auth_data=_auth_data(acct_id),
|
||||
)
|
||||
|
||||
|
||||
@ -762,10 +774,8 @@ def test_update_role_happy_path(app, bypass_pipeline, monkeypatch):
|
||||
content_type="application/json",
|
||||
):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
body, status = api.put.__wrapped__.__wrapped__.__wrapped__(
|
||||
api,
|
||||
workspace_id=ws_id,
|
||||
member_id=member_id,
|
||||
body, status = api.put.__wrapped__.__wrapped__(
|
||||
api, workspace_id=ws_id, member_id=member_id, auth_data=_auth_data(acct_id)
|
||||
)
|
||||
|
||||
assert status == 200
|
||||
@ -810,10 +820,11 @@ def test_update_role_exception_mapping(app, bypass_pipeline, monkeypatch, exc, e
|
||||
):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(expected):
|
||||
api.put.__wrapped__.__wrapped__.__wrapped__(
|
||||
api.put.__wrapped__.__wrapped__(
|
||||
api,
|
||||
workspace_id=ws_id,
|
||||
member_id=member_id,
|
||||
auth_data=_auth_data(acct_id),
|
||||
)
|
||||
|
||||
|
||||
@ -847,9 +858,8 @@ def test_non_member_caller_gets_404_on_switch(app, bypass_pipeline, monkeypatch)
|
||||
# Strip only the bearer + surface-gate wrappers; keep the role gate.
|
||||
# Decorator stack (innermost → outermost):
|
||||
# role_gate → accept_subjects → validate_bearer
|
||||
# So `post.__wrapped__` unwraps validate_bearer; we then unwrap
|
||||
# accept_subjects to land on the role-gate wrapper.
|
||||
gated = api.post.__wrapped__.__wrapped__
|
||||
# `post.__wrapped__` is now the role-gate wrapper directly (auth_router.guard is the only outer wrapper).
|
||||
gated = api.post.__wrapped__
|
||||
with pytest.raises(NotFound):
|
||||
gated(api, workspace_id=ws_id)
|
||||
|
||||
@ -881,7 +891,7 @@ def test_load_tenant_rejects_archived_workspace(app, bypass_pipeline, monkeypatc
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members"):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(NotFound):
|
||||
api.get.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
|
||||
api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -915,4 +925,4 @@ def test_invite_400_when_register_error(app, bypass_pipeline, monkeypatch):
|
||||
):
|
||||
_seed(_auth_ctx(account_id=acct_id))
|
||||
with pytest.raises(BadRequest):
|
||||
api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
|
||||
api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
|
||||
from clients.agent_backend import (
|
||||
AgentBackendRunEventAdapter,
|
||||
AgentBackendStreamInternalEvent,
|
||||
CleanupLayerSpec,
|
||||
FakeAgentBackendRunClient,
|
||||
FakeAgentBackendScenario,
|
||||
)
|
||||
@ -13,9 +16,10 @@ from core.workflow.nodes.agent_v2.binding_resolver import WorkflowAgentBindingBu
|
||||
from core.workflow.nodes.agent_v2.entities import DifyAgentNodeData
|
||||
from core.workflow.nodes.agent_v2.output_adapter import WorkflowAgentOutputAdapter
|
||||
from core.workflow.nodes.agent_v2.runtime_request_builder import WorkflowAgentRuntimeRequestBuilder
|
||||
from core.workflow.nodes.agent_v2.session_store import WorkflowAgentRuntimeSessionStore, WorkflowAgentSessionScope
|
||||
from graphon.entities import GraphInitParams
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from graphon.node_events import StreamCompletedEvent
|
||||
from graphon.node_events import PauseRequestedEvent, StreamCompletedEvent
|
||||
from graphon.runtime import GraphRuntimeState
|
||||
from graphon.variables.segments import StringSegment
|
||||
from models.agent import Agent, AgentConfigSnapshot, WorkflowAgentNodeBinding
|
||||
@ -84,7 +88,47 @@ class FakeBindingResolver(WorkflowAgentBindingResolver):
|
||||
return WorkflowAgentBindingBundle(binding=self.binding, agent=self.agent, snapshot=self.snapshot)
|
||||
|
||||
|
||||
def _node(*, scenario: FakeAgentBackendScenario = FakeAgentBackendScenario.SUCCESS) -> DifyAgentNode:
|
||||
class FakeSessionStore:
|
||||
def __init__(self, snapshot: CompositorSessionSnapshot | None = None) -> None:
|
||||
self.loaded_snapshot = snapshot
|
||||
self.saved: list[
|
||||
tuple[
|
||||
WorkflowAgentSessionScope,
|
||||
str,
|
||||
CompositorSessionSnapshot | None,
|
||||
list[CleanupLayerSpec],
|
||||
]
|
||||
] = []
|
||||
self.cleaned: list[tuple[WorkflowAgentSessionScope, str | None]] = []
|
||||
|
||||
def load_active_snapshot(self, scope: WorkflowAgentSessionScope) -> CompositorSessionSnapshot | None:
|
||||
return self.loaded_snapshot
|
||||
|
||||
def save_active_snapshot(
|
||||
self,
|
||||
*,
|
||||
scope: WorkflowAgentSessionScope,
|
||||
backend_run_id: str,
|
||||
snapshot: CompositorSessionSnapshot | None,
|
||||
composition_layer_specs: list[CleanupLayerSpec],
|
||||
) -> None:
|
||||
self.saved.append((scope, backend_run_id, snapshot, list(composition_layer_specs)))
|
||||
|
||||
def mark_cleaned(
|
||||
self,
|
||||
*,
|
||||
scope: WorkflowAgentSessionScope,
|
||||
backend_run_id: str | None = None,
|
||||
) -> None:
|
||||
self.cleaned.append((scope, backend_run_id))
|
||||
|
||||
|
||||
def _node(
|
||||
*,
|
||||
scenario: FakeAgentBackendScenario = FakeAgentBackendScenario.SUCCESS,
|
||||
agent_backend_client: FakeAgentBackendRunClient | None = None,
|
||||
session_store: FakeSessionStore | None = None,
|
||||
) -> DifyAgentNode:
|
||||
graph_init_params = GraphInitParams(
|
||||
workflow_id="workflow-1",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
@ -106,6 +150,7 @@ def _node(*, scenario: FakeAgentBackendScenario = FakeAgentBackendScenario.SUCCE
|
||||
def is_owned_by_tenant(self, *, file_id: str, tenant_id: str) -> bool:
|
||||
return True
|
||||
|
||||
client = agent_backend_client or FakeAgentBackendRunClient(scenario=scenario)
|
||||
return DifyAgentNode(
|
||||
node_id="agent-node",
|
||||
data=DifyAgentNodeData.model_validate({"type": BuiltinNodeTypes.AGENT, "version": "2"}),
|
||||
@ -113,11 +158,12 @@ def _node(*, scenario: FakeAgentBackendScenario = FakeAgentBackendScenario.SUCCE
|
||||
graph_runtime_state=cast(GraphRuntimeState, SimpleNamespace(variable_pool=FakeVariablePool())),
|
||||
binding_resolver=FakeBindingResolver(),
|
||||
runtime_request_builder=WorkflowAgentRuntimeRequestBuilder(credentials_provider=FakeCredentialsProvider()),
|
||||
agent_backend_client=FakeAgentBackendRunClient(scenario=scenario),
|
||||
agent_backend_client=client,
|
||||
event_adapter=AgentBackendRunEventAdapter(),
|
||||
output_adapter=WorkflowAgentOutputAdapter(),
|
||||
type_checker=PerOutputTypeChecker(file_validator=_AlwaysAllowFileValidator()),
|
||||
failure_orchestrator=OutputFailureOrchestrator(),
|
||||
session_store=cast(WorkflowAgentRuntimeSessionStore | None, session_store),
|
||||
)
|
||||
|
||||
|
||||
@ -132,7 +178,7 @@ def test_agent_node_run_maps_successful_agent_backend_run_to_node_result():
|
||||
assert agent_log["agent_backend"]["run_id"] == "fake-run-1"
|
||||
assert agent_log["agent_backend"]["status"] == "succeeded"
|
||||
assert result.process_data["agent_id"] == "agent-1"
|
||||
assert result.inputs["agent_backend_request"]["composition"]["layers"][4]["config"]["credentials"] == "[REDACTED]"
|
||||
assert result.inputs["agent_backend_request"]["composition"]["layers"][5]["config"]["credentials"] == "[REDACTED]"
|
||||
|
||||
|
||||
def test_agent_node_run_maps_failed_agent_backend_run_to_node_result():
|
||||
@ -145,6 +191,126 @@ def test_agent_node_run_maps_failed_agent_backend_run_to_node_result():
|
||||
assert result.error_type == "unit_test"
|
||||
|
||||
|
||||
def test_agent_node_failed_run_marks_session_cleaned_to_prevent_stale_reuse():
|
||||
"""A failed agent run must retire the local ACTIVE session row so a workflow
|
||||
loop back into the same Agent node does not resume from a stale snapshot."""
|
||||
existing_snapshot = CompositorSessionSnapshot(layers=[])
|
||||
store = FakeSessionStore(snapshot=existing_snapshot)
|
||||
|
||||
events = list(_node(scenario=FakeAgentBackendScenario.FAILED, session_store=store)._run())
|
||||
|
||||
assert len(events) == 1
|
||||
assert store.cleaned, "failed agent run should mark the session cleaned"
|
||||
cleaned_scope, cleaned_backend_run_id = store.cleaned[0]
|
||||
assert cleaned_scope.workflow_run_id == "workflow-run-1"
|
||||
assert cleaned_backend_run_id == "fake-run-1"
|
||||
# A failed run does not produce a fresh snapshot to persist.
|
||||
assert store.saved == []
|
||||
|
||||
|
||||
def test_agent_node_saves_success_snapshot_and_reuses_existing_snapshot():
|
||||
existing_snapshot = CompositorSessionSnapshot(layers=[])
|
||||
store = FakeSessionStore(snapshot=existing_snapshot)
|
||||
client = FakeAgentBackendRunClient()
|
||||
node = _node(agent_backend_client=client, session_store=store)
|
||||
|
||||
events = list(node._run())
|
||||
|
||||
assert len(events) == 1
|
||||
assert store.saved
|
||||
scope, backend_run_id, saved_snapshot, saved_specs = store.saved[0]
|
||||
assert scope.workflow_run_id == "workflow-run-1"
|
||||
assert backend_run_id == "fake-run-1"
|
||||
assert saved_snapshot is not None
|
||||
assert client.request is not None
|
||||
assert client.request.session_snapshot is existing_snapshot
|
||||
# Persist enough composition shape to replay a cleanup run; plugin layers
|
||||
# (which would carry credentials) are intentionally absent.
|
||||
saved_layer_names = [spec.name for spec in saved_specs]
|
||||
assert saved_layer_names, "cleanup specs must persist at least the non-plugin layers"
|
||||
plugin_types = {"dify.plugin.llm", "dify.plugin.tools"}
|
||||
assert not {spec.type for spec in saved_specs} & plugin_types
|
||||
|
||||
|
||||
def test_agent_node_run_when_session_store_save_raises_records_persist_error_in_metadata():
|
||||
"""A DB-side write failure must not crash the node; it should set
|
||||
``session_snapshot_persist_error`` in the agent_backend metadata so the
|
||||
incident is observable from the workflow_node_executions record."""
|
||||
|
||||
class _ExplodingSessionStore(FakeSessionStore):
|
||||
def save_active_snapshot(self, **kwargs): # type: ignore[override]
|
||||
del kwargs
|
||||
raise RuntimeError("simulated DB failure")
|
||||
|
||||
store = _ExplodingSessionStore()
|
||||
events = list(_node(session_store=store)._run())
|
||||
|
||||
assert len(events) == 1
|
||||
result = cast(StreamCompletedEvent, events[0]).node_run_result
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
agent_backend = result.metadata[WorkflowNodeExecutionMetadataKey.AGENT_LOG]["agent_backend"]
|
||||
assert agent_backend["session_snapshot_persisted"] is False
|
||||
assert agent_backend["session_snapshot_persist_error"] == "workflow_agent_runtime_session_store_error"
|
||||
|
||||
|
||||
def test_agent_node_failed_run_when_mark_cleaned_raises_records_cleanup_error_in_metadata():
|
||||
"""Same defensive pattern: a DB-side mark_cleaned failure must surface as
|
||||
a ``session_snapshot_cleanup_error`` in metadata, not as a node crash."""
|
||||
|
||||
class _ExplodingMarkCleanedStore(FakeSessionStore):
|
||||
def mark_cleaned(self, **kwargs): # type: ignore[override]
|
||||
del kwargs
|
||||
raise RuntimeError("simulated DB failure")
|
||||
|
||||
store = _ExplodingMarkCleanedStore()
|
||||
events = list(_node(scenario=FakeAgentBackendScenario.FAILED, session_store=store)._run())
|
||||
|
||||
assert len(events) == 1
|
||||
result = cast(StreamCompletedEvent, events[0]).node_run_result
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
agent_backend = result.metadata[WorkflowNodeExecutionMetadataKey.AGENT_LOG]["agent_backend"]
|
||||
assert agent_backend["session_snapshot_cleaned_on_failure"] is False
|
||||
assert agent_backend["session_snapshot_cleanup_error"] == "workflow_agent_runtime_session_store_error"
|
||||
|
||||
|
||||
def test_agent_node_success_run_without_session_store_skips_persistence():
|
||||
"""When ``session_store`` is None the node still completes successfully —
|
||||
the lifecycle branch is a no-op and the run result is unaffected."""
|
||||
events = list(_node(session_store=None)._run())
|
||||
|
||||
assert len(events) == 1
|
||||
result = cast(StreamCompletedEvent, events[0]).node_run_result
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
agent_backend = result.metadata[WorkflowNodeExecutionMetadataKey.AGENT_LOG]["agent_backend"]
|
||||
# No persistence metadata is attached when the store is missing.
|
||||
assert "session_snapshot_persisted" not in agent_backend
|
||||
|
||||
|
||||
def test_agent_node_failed_run_without_session_store_skips_mark_cleaned():
|
||||
"""``session_store=None`` + failed terminal must remain a no-op for
|
||||
the cleanup branch — the node failure path still surfaces correctly."""
|
||||
events = list(_node(scenario=FakeAgentBackendScenario.FAILED, session_store=None)._run())
|
||||
|
||||
assert len(events) == 1
|
||||
result = cast(StreamCompletedEvent, events[0]).node_run_result
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
agent_backend = result.metadata[WorkflowNodeExecutionMetadataKey.AGENT_LOG]["agent_backend"]
|
||||
assert "session_snapshot_cleaned_on_failure" not in agent_backend
|
||||
|
||||
|
||||
def test_agent_node_paused_run_requests_workflow_pause_and_persists_snapshot():
|
||||
store = FakeSessionStore()
|
||||
node = _node(scenario=FakeAgentBackendScenario.PAUSED, session_store=store)
|
||||
|
||||
events = list(node._run())
|
||||
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], PauseRequestedEvent)
|
||||
assert store.saved
|
||||
assert store.saved[0][1] == "fake-run-1"
|
||||
assert store.saved[0][3], "paused agent run should still persist replayable layer specs"
|
||||
|
||||
|
||||
def test_agent_node_records_stream_usage_metadata():
|
||||
metadata = {"agent_backend": {"run_id": "run-1"}}
|
||||
|
||||
|
||||
@ -1,10 +1,14 @@
|
||||
from dataclasses import replace
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from dify_agent.layers.dify_plugin import DifyPluginToolConfig, DifyPluginToolsLayerConfig
|
||||
from dify_agent.protocol import DIFY_AGENT_HISTORY_LAYER_ID, DIFY_AGENT_MODEL_LAYER_ID
|
||||
|
||||
from clients.agent_backend import DIFY_EXECUTION_CONTEXT_LAYER_ID, DIFY_PLUGIN_TOOLS_LAYER_ID
|
||||
from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom
|
||||
from core.workflow.nodes.agent_v2.plugin_tools_builder import WorkflowAgentPluginToolsBuilder
|
||||
from core.workflow.nodes.agent_v2.runtime_request_builder import (
|
||||
WorkflowAgentRuntimeBuildContext,
|
||||
WorkflowAgentRuntimeRequestBuilder,
|
||||
@ -27,6 +31,17 @@ class FakeCredentialsProvider:
|
||||
return {"api_key": "secret-key"}
|
||||
|
||||
|
||||
class CapturingCredentialsProvider:
|
||||
def __init__(self) -> None:
|
||||
self.provider_name: str | None = None
|
||||
self.model_name: str | None = None
|
||||
|
||||
def fetch(self, provider_name: str, model_name: str) -> dict[str, object]:
|
||||
self.provider_name = provider_name
|
||||
self.model_name = model_name
|
||||
return {"api_key": "secret-key"}
|
||||
|
||||
|
||||
class FakePluginToolsBuilder:
|
||||
def __init__(self) -> None:
|
||||
# Capture the runtime invocation source so tests can assert it was
|
||||
@ -136,7 +151,31 @@ def test_builds_create_run_request_from_agent_soul_and_node_job():
|
||||
assert dumped["composition"]["layers"][1]["config"]["prefix"] == "Use the previous output."
|
||||
assert "Previous result" in dumped["composition"]["layers"][2]["config"]["user"]
|
||||
assert dumped["composition"]["layers"][-1]["config"]["json_schema"]["properties"]["summary"]["type"] == "string"
|
||||
assert result.redacted_request["composition"]["layers"][4]["config"]["credentials"] == "[REDACTED]"
|
||||
assert DIFY_AGENT_HISTORY_LAYER_ID in layers
|
||||
assert result.redacted_request["composition"]["layers"][5]["config"]["credentials"] == "[REDACTED]"
|
||||
|
||||
|
||||
def test_normalizes_langgenius_model_provider_for_agent_backend_transport():
|
||||
context = _context()
|
||||
context.snapshot.config_snapshot = AgentSoulConfig(
|
||||
prompt={"system_prompt": "You are careful."},
|
||||
model=AgentSoulModelConfig(
|
||||
plugin_id="langgenius/openai/openai",
|
||||
model_provider="langgenius/openai/openai",
|
||||
model="gpt-test",
|
||||
),
|
||||
)
|
||||
credentials_provider = CapturingCredentialsProvider()
|
||||
|
||||
result = WorkflowAgentRuntimeRequestBuilder(credentials_provider=credentials_provider).build(context)
|
||||
|
||||
dumped = result.request.model_dump(mode="json")
|
||||
layers = {layer["name"]: layer for layer in dumped["composition"]["layers"]}
|
||||
model_config = layers[DIFY_AGENT_MODEL_LAYER_ID]["config"]
|
||||
assert credentials_provider.provider_name == "langgenius/openai/openai"
|
||||
assert credentials_provider.model_name == "gpt-test"
|
||||
assert model_config["plugin_id"] == "langgenius/openai"
|
||||
assert model_config["model_provider"] == "openai"
|
||||
|
||||
|
||||
def test_builds_workflow_run_request_with_file_output_schema_and_reserved_metadata():
|
||||
@ -187,7 +226,7 @@ def test_builds_workflow_run_request_with_file_output_schema_and_reserved_metada
|
||||
assert output_schema["properties"]["report"]["properties"]["file_id"]["type"] == "string"
|
||||
assert output_schema["properties"]["confidence"]["type"] == "number"
|
||||
assert output_schema["required"] == ["report"]
|
||||
assert dumped["composition"]["layers"][4]["config"]["model_settings"] == {"temperature": 0.2}
|
||||
assert dumped["composition"]["layers"][5]["config"]["model_settings"] == {"temperature": 0.2}
|
||||
assert result.metadata["runtime_support"]["reserved_status"]["tools.dify_tools"] == "supported_when_config_valid"
|
||||
assert result.metadata["runtime_support"]["reserved_status"]["tools.cli_tools"] == "reserved_not_executed"
|
||||
warnings = result.metadata["runtime_support"]["unsupported_runtime_warnings"]
|
||||
@ -224,7 +263,7 @@ def test_builds_workflow_run_request_with_dify_plugin_tools_layer():
|
||||
plugin_tools_builder = FakePluginToolsBuilder()
|
||||
result = WorkflowAgentRuntimeRequestBuilder(
|
||||
credentials_provider=FakeCredentialsProvider(),
|
||||
plugin_tools_builder=plugin_tools_builder,
|
||||
plugin_tools_builder=cast(WorkflowAgentPluginToolsBuilder, plugin_tools_builder),
|
||||
).build(context)
|
||||
|
||||
dumped = result.request.model_dump(mode="json")
|
||||
@ -244,6 +283,15 @@ def test_builds_workflow_run_request_with_dify_plugin_tools_layer():
|
||||
assert plugin_tools_builder.last_invoke_from == context.dify_context.invoke_from
|
||||
|
||||
|
||||
def test_build_passes_saved_session_snapshot_to_agent_backend_request():
|
||||
session_snapshot = CompositorSessionSnapshot(layers=[])
|
||||
context = replace(_context(), session_snapshot=session_snapshot)
|
||||
|
||||
result = WorkflowAgentRuntimeRequestBuilder(credentials_provider=FakeCredentialsProvider()).build(context)
|
||||
|
||||
assert result.request.session_snapshot is session_snapshot
|
||||
|
||||
|
||||
def test_requires_agent_soul_model_config():
|
||||
context = _context()
|
||||
snapshot = AgentConfigSnapshot(
|
||||
|
||||
@ -0,0 +1,412 @@
|
||||
from datetime import UTC
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from agenton.compositor.schemas import LayerSessionSnapshot
|
||||
from agenton.layers.base import LifecycleState
|
||||
from dify_agent.protocol import CancelRunRequest, RunEvent, RunStatusResponse
|
||||
|
||||
from clients.agent_backend import AgentBackendRunRequestBuilder, CleanupLayerSpec, FakeAgentBackendRunClient
|
||||
from clients.agent_backend.errors import AgentBackendHTTPError
|
||||
from core.workflow.nodes.agent_v2.session_cleanup_layer import WorkflowAgentSessionCleanupLayer
|
||||
from core.workflow.nodes.agent_v2.session_store import (
|
||||
StoredWorkflowAgentSession,
|
||||
WorkflowAgentRuntimeSessionStore,
|
||||
WorkflowAgentSessionScope,
|
||||
)
|
||||
from core.workflow.system_variables import build_system_variables
|
||||
from graphon.entities.pause_reason import SchedulingPause
|
||||
from graphon.graph_engine.command_channels import CommandChannel
|
||||
from graphon.graph_events import (
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
|
||||
|
||||
|
||||
def _layer_snapshot(name: str) -> LayerSessionSnapshot:
|
||||
return LayerSessionSnapshot(
|
||||
name=name,
|
||||
lifecycle_state=LifecycleState.SUSPENDED,
|
||||
runtime_state={},
|
||||
)
|
||||
|
||||
|
||||
def _stored_session(scope: WorkflowAgentSessionScope, *, index: int = 1) -> StoredWorkflowAgentSession:
|
||||
"""A typical stored session with prompt + execution_context + history + llm specs.
|
||||
|
||||
The LLM layer is *not* in ``composition_layer_specs`` because the cleanup
|
||||
contract excludes credential-bearing plugin layers, but it *is* present in
|
||||
the saved snapshot so the layer's filter logic gets exercised.
|
||||
"""
|
||||
return StoredWorkflowAgentSession(
|
||||
scope=scope,
|
||||
session_snapshot=CompositorSessionSnapshot(
|
||||
layers=[
|
||||
_layer_snapshot("workflow_node_job_prompt"),
|
||||
_layer_snapshot("execution_context"),
|
||||
_layer_snapshot("history"),
|
||||
_layer_snapshot("llm"),
|
||||
]
|
||||
),
|
||||
backend_run_id=f"agent-run-{index}",
|
||||
composition_layer_specs=[
|
||||
CleanupLayerSpec(name="workflow_node_job_prompt", type="plain.prompt", config={"prefix": "ok"}),
|
||||
CleanupLayerSpec(name="execution_context", type="dify.execution_context", config={"tenant_id": "t"}),
|
||||
CleanupLayerSpec(name="history", type="pydantic_ai.history"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class FakeSessionStore:
|
||||
"""In-memory stand-in for ``WorkflowAgentRuntimeSessionStore``."""
|
||||
|
||||
def __init__(self, *, stored: list[StoredWorkflowAgentSession] | None = None) -> None:
|
||||
self._stored = stored if stored is not None else [_stored_session(_default_scope())]
|
||||
self.list_calls: list[str] = []
|
||||
self.cleaned: list[tuple[WorkflowAgentSessionScope, str | None]] = []
|
||||
|
||||
def list_active_sessions(self, *, workflow_run_id: str) -> list[StoredWorkflowAgentSession]:
|
||||
self.list_calls.append(workflow_run_id)
|
||||
return list(self._stored)
|
||||
|
||||
def mark_cleaned(self, *, scope: WorkflowAgentSessionScope, backend_run_id: str | None = None) -> None:
|
||||
self.cleaned.append((scope, backend_run_id))
|
||||
|
||||
|
||||
def _default_scope() -> WorkflowAgentSessionScope:
|
||||
return WorkflowAgentSessionScope(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
workflow_id="workflow-1",
|
||||
workflow_run_id="workflow-run-1",
|
||||
node_id="agent-node",
|
||||
node_execution_id="node-exec-1",
|
||||
binding_id="binding-1",
|
||||
agent_id="agent-1",
|
||||
agent_config_snapshot_id="snapshot-1",
|
||||
)
|
||||
|
||||
|
||||
class _WaitableFakeAgentBackendRunClient(FakeAgentBackendRunClient):
|
||||
"""``FakeAgentBackendRunClient`` plus the ``wait_run`` hook the layer needs."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
run_id: str = "cleanup-run-1",
|
||||
wait_status: str = "succeeded",
|
||||
wait_error: str | None = None,
|
||||
wait_raises: Exception | None = None,
|
||||
) -> None:
|
||||
super().__init__(run_id=run_id)
|
||||
self._wait_status = wait_status
|
||||
self._wait_error = wait_error
|
||||
self._wait_raises = wait_raises
|
||||
self.wait_calls: list[tuple[str, float | None]] = []
|
||||
|
||||
def wait_run(self, run_id: str, *, timeout_seconds: float | None = None) -> RunStatusResponse:
|
||||
self.wait_calls.append((run_id, timeout_seconds))
|
||||
if self._wait_raises is not None:
|
||||
raise self._wait_raises
|
||||
from datetime import datetime
|
||||
|
||||
return RunStatusResponse(
|
||||
run_id=run_id,
|
||||
status=cast(object, self._wait_status), # protocol Literal; cast keeps tests flexible
|
||||
created_at=datetime(2026, 1, 1, tzinfo=UTC),
|
||||
updated_at=datetime(2026, 1, 1, tzinfo=UTC),
|
||||
error=self._wait_error,
|
||||
)
|
||||
|
||||
# Inherit ``create_run`` from FakeAgentBackendRunClient; the missing protocol
|
||||
# methods below are stub-only because the cleanup layer never calls them.
|
||||
def cancel_run(self, run_id: str, request: CancelRunRequest | None = None): # pragma: no cover
|
||||
del run_id, request
|
||||
raise NotImplementedError
|
||||
|
||||
def stream_events(self, run_id: str, *, after: str | None = None): # pragma: no cover
|
||||
del run_id, after
|
||||
if False:
|
||||
yield cast(RunEvent, None)
|
||||
|
||||
|
||||
def _build_layer(
|
||||
*,
|
||||
session_store: FakeSessionStore,
|
||||
agent_backend_client: _WaitableFakeAgentBackendRunClient,
|
||||
http_cleanup_supported: bool = True,
|
||||
) -> WorkflowAgentSessionCleanupLayer:
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id="workflow-run-1"),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
layer = WorkflowAgentSessionCleanupLayer(
|
||||
session_store=cast(WorkflowAgentRuntimeSessionStore, session_store),
|
||||
request_builder=AgentBackendRunRequestBuilder(),
|
||||
agent_backend_client=agent_backend_client,
|
||||
)
|
||||
# Tests opt in to the future HTTP-cleanup branch; the production default
|
||||
# (False) is exercised by the dedicated tests below.
|
||||
layer._HTTP_CLEANUP_SUPPORTED = http_cleanup_supported # type: ignore[reportPrivateUsage]
|
||||
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(runtime_state), cast(CommandChannel, object()))
|
||||
return layer
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"terminal_event",
|
||||
[
|
||||
GraphRunSucceededEvent(outputs={}),
|
||||
GraphRunPartialSucceededEvent(exceptions_count=1, outputs={}),
|
||||
GraphRunFailedEvent(error="boom"),
|
||||
GraphRunAbortedEvent(reason="user cancelled", outputs={}),
|
||||
],
|
||||
ids=["succeeded", "partial_succeeded", "failed", "aborted"],
|
||||
)
|
||||
def test_cleanup_layer_triggers_cleanup_only_run_on_each_terminal_event(terminal_event):
|
||||
session_store = FakeSessionStore()
|
||||
agent_backend_client = _WaitableFakeAgentBackendRunClient()
|
||||
layer = _build_layer(session_store=session_store, agent_backend_client=agent_backend_client)
|
||||
|
||||
layer.on_event(terminal_event)
|
||||
|
||||
assert session_store.list_calls == ["workflow-run-1"]
|
||||
assert agent_backend_client.request is not None
|
||||
# Cleanup composition replays the persisted (non-plugin) layer specs so the
|
||||
# agent backend's snapshot-vs-composition name match succeeds.
|
||||
layer_names = [layer.name for layer in agent_backend_client.request.composition.layers]
|
||||
assert layer_names == ["workflow_node_job_prompt", "execution_context", "history"]
|
||||
assert agent_backend_client.request.on_exit.default.value == "delete"
|
||||
assert agent_backend_client.request.metadata["agent_backend_lifecycle"] == "session_cleanup"
|
||||
# Snapshot is filtered to drop the plugin layer entry so names match the
|
||||
# cleanup composition.
|
||||
assert agent_backend_client.request.session_snapshot is not None
|
||||
snapshot_names = [layer.name for layer in agent_backend_client.request.session_snapshot.layers]
|
||||
assert snapshot_names == ["workflow_node_job_prompt", "execution_context", "history"]
|
||||
# The layer waited for terminal status and the run succeeded, so the row
|
||||
# is marked CLEANED with the cleanup run id.
|
||||
assert agent_backend_client.wait_calls
|
||||
assert session_store.cleaned == [(_default_scope(), "cleanup-run-1")]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"non_terminal_event",
|
||||
[
|
||||
GraphRunStartedEvent(),
|
||||
GraphRunPausedEvent(reasons=[SchedulingPause(message="awaiting human input")], outputs={}),
|
||||
],
|
||||
ids=["started", "paused"],
|
||||
)
|
||||
def test_cleanup_layer_ignores_non_terminal_events(non_terminal_event):
|
||||
session_store = FakeSessionStore()
|
||||
agent_backend_client = _WaitableFakeAgentBackendRunClient()
|
||||
layer = _build_layer(session_store=session_store, agent_backend_client=agent_backend_client)
|
||||
|
||||
layer.on_event(non_terminal_event)
|
||||
|
||||
assert session_store.list_calls == []
|
||||
assert agent_backend_client.request is None
|
||||
assert session_store.cleaned == []
|
||||
|
||||
|
||||
def test_cleanup_layer_does_not_mark_cleaned_when_cleanup_run_fails():
|
||||
"""Trap D: cleanup-only run goes ``run_failed`` (e.g. snapshot validation
|
||||
error) — the layer must leave the row ACTIVE so it can be retried instead
|
||||
of silently leaking suspended agent-backend layers."""
|
||||
session_store = FakeSessionStore()
|
||||
agent_backend_client = _WaitableFakeAgentBackendRunClient(
|
||||
wait_status="failed",
|
||||
wait_error="snapshot mismatch",
|
||||
)
|
||||
layer = _build_layer(session_store=session_store, agent_backend_client=agent_backend_client)
|
||||
|
||||
layer.on_event(GraphRunSucceededEvent(outputs={}))
|
||||
|
||||
assert agent_backend_client.wait_calls
|
||||
assert session_store.cleaned == []
|
||||
|
||||
|
||||
def test_cleanup_layer_does_not_mark_cleaned_when_wait_raises():
|
||||
session_store = FakeSessionStore()
|
||||
agent_backend_client = _WaitableFakeAgentBackendRunClient(
|
||||
wait_raises=AgentBackendHTTPError("boom", status_code=500, detail=None),
|
||||
)
|
||||
layer = _build_layer(session_store=session_store, agent_backend_client=agent_backend_client)
|
||||
|
||||
layer.on_event(GraphRunSucceededEvent(outputs={}))
|
||||
|
||||
assert session_store.cleaned == []
|
||||
|
||||
|
||||
def test_cleanup_layer_marks_cleaned_locally_when_http_cleanup_disabled():
|
||||
"""Production default: dify-agent has no cleanup-only run mode yet, so the
|
||||
layer must retire the local row without issuing a doomed HTTP request that
|
||||
would crash inside the agent backend's runner on the missing LLM layer."""
|
||||
session_store = FakeSessionStore()
|
||||
agent_backend_client = _WaitableFakeAgentBackendRunClient()
|
||||
layer = _build_layer(
|
||||
session_store=session_store,
|
||||
agent_backend_client=agent_backend_client,
|
||||
http_cleanup_supported=False,
|
||||
)
|
||||
|
||||
layer.on_event(GraphRunSucceededEvent(outputs={}))
|
||||
|
||||
# No HTTP call goes out — the trap is avoided entirely.
|
||||
assert agent_backend_client.request is None
|
||||
assert agent_backend_client.wait_calls == []
|
||||
# Local row is still retired so a workflow loop cannot resume from stale state.
|
||||
assert session_store.cleaned == [(_default_scope(), "agent-run-1")]
|
||||
|
||||
|
||||
def test_cleanup_layer_skips_sessions_without_persisted_specs():
|
||||
"""Backwards-compatible safety net: a row written before A.1 landed has
|
||||
no composition_layer_specs, so cleanup would unavoidably hit the snapshot-
|
||||
validation trap. The layer must skip such rows instead of issuing a
|
||||
doomed request."""
|
||||
scope = _default_scope()
|
||||
legacy_session = StoredWorkflowAgentSession(
|
||||
scope=scope,
|
||||
session_snapshot=CompositorSessionSnapshot(layers=[_layer_snapshot("history")]),
|
||||
backend_run_id="legacy-run",
|
||||
composition_layer_specs=[],
|
||||
)
|
||||
session_store = FakeSessionStore(stored=[legacy_session])
|
||||
agent_backend_client = _WaitableFakeAgentBackendRunClient()
|
||||
layer = _build_layer(session_store=session_store, agent_backend_client=agent_backend_client)
|
||||
|
||||
layer.on_event(GraphRunSucceededEvent(outputs={}))
|
||||
|
||||
assert agent_backend_client.request is None
|
||||
assert session_store.cleaned == []
|
||||
|
||||
|
||||
def test_cleanup_layer_fans_out_to_every_active_session():
|
||||
scopes = [
|
||||
WorkflowAgentSessionScope(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
workflow_id="workflow-1",
|
||||
workflow_run_id="workflow-run-1",
|
||||
node_id=f"agent-node-{i}",
|
||||
node_execution_id=f"node-exec-{i}",
|
||||
binding_id=f"binding-{i}",
|
||||
agent_id=f"agent-{i}",
|
||||
agent_config_snapshot_id=f"snapshot-{i}",
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
session_store = FakeSessionStore(stored=[_stored_session(scope, index=i) for i, scope in enumerate(scopes, 1)])
|
||||
agent_backend_client = _WaitableFakeAgentBackendRunClient(run_id="cleanup-run-many")
|
||||
layer = _build_layer(session_store=session_store, agent_backend_client=agent_backend_client)
|
||||
|
||||
layer.on_event(GraphRunSucceededEvent(outputs={}))
|
||||
|
||||
# One cleanup row per stored ACTIVE session, all marked cleaned with the
|
||||
# backend run id returned by the agent backend client.
|
||||
assert [entry[0] for entry in session_store.cleaned] == scopes
|
||||
assert {entry[1] for entry in session_store.cleaned} == {"cleanup-run-many"}
|
||||
|
||||
|
||||
def test_cleanup_layer_warns_when_http_enabled_but_client_missing(caplog):
|
||||
"""The HTTP cleanup branch must defensively skip when no client was wired.
|
||||
|
||||
This is the deployment-misconfig path: ``_HTTP_CLEANUP_SUPPORTED`` was
|
||||
flipped to ``True`` but ``AGENT_BACKEND_BASE_URL`` is unset, so the
|
||||
factory returned ``None``. The layer must not crash and must not silently
|
||||
retire the row — the warning surfaces the misconfig.
|
||||
"""
|
||||
import logging
|
||||
|
||||
session_store = FakeSessionStore()
|
||||
layer = WorkflowAgentSessionCleanupLayer(
|
||||
session_store=cast(WorkflowAgentRuntimeSessionStore, session_store),
|
||||
request_builder=AgentBackendRunRequestBuilder(),
|
||||
agent_backend_client=None,
|
||||
)
|
||||
layer._HTTP_CLEANUP_SUPPORTED = True # type: ignore[reportPrivateUsage]
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id="workflow-run-1"),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(runtime_state), cast(CommandChannel, object()))
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
layer.on_event(GraphRunSucceededEvent(outputs={}))
|
||||
|
||||
assert session_store.cleaned == []
|
||||
assert any("no agent backend client is wired in" in record.message for record in caplog.records)
|
||||
|
||||
|
||||
def test_cleanup_layer_skips_workflow_terminal_when_workflow_run_id_missing(caplog):
|
||||
"""``workflow_run_id`` is the keying field; without it the fanout cannot
|
||||
target a row, so the layer logs a warning and bails."""
|
||||
import logging
|
||||
|
||||
session_store = FakeSessionStore()
|
||||
agent_backend_client = _WaitableFakeAgentBackendRunClient()
|
||||
layer = WorkflowAgentSessionCleanupLayer(
|
||||
session_store=cast(WorkflowAgentRuntimeSessionStore, session_store),
|
||||
request_builder=AgentBackendRunRequestBuilder(),
|
||||
agent_backend_client=agent_backend_client,
|
||||
)
|
||||
# Bootstrap *without* a workflow_execution_id system variable.
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id=""),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(runtime_state), cast(CommandChannel, object()))
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
layer.on_event(GraphRunSucceededEvent(outputs={}))
|
||||
|
||||
assert session_store.list_calls == []
|
||||
assert session_store.cleaned == []
|
||||
assert any("workflow_run_id is missing" in record.message for record in caplog.records)
|
||||
|
||||
|
||||
def test_build_workflow_agent_session_cleanup_layer_returns_layer_without_client_when_unconfigured(
|
||||
monkeypatch,
|
||||
):
|
||||
"""The production builder must pass ``None`` for the agent backend client
|
||||
when neither AGENT_BACKEND_BASE_URL nor AGENT_BACKEND_USE_FAKE is set, so
|
||||
that unit-test environments without backend config don't crash at runner
|
||||
construction."""
|
||||
from configs import dify_config
|
||||
from core.workflow.nodes.agent_v2.session_cleanup_layer import (
|
||||
build_workflow_agent_session_cleanup_layer,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(dify_config, "AGENT_BACKEND_BASE_URL", None, raising=False)
|
||||
monkeypatch.setattr(dify_config, "AGENT_BACKEND_USE_FAKE", False, raising=False)
|
||||
|
||||
layer = build_workflow_agent_session_cleanup_layer()
|
||||
assert layer._agent_backend_client is None # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
def test_build_workflow_agent_session_cleanup_layer_returns_layer_with_fake_client(monkeypatch):
|
||||
"""With ``AGENT_BACKEND_USE_FAKE`` enabled the helper wires in the
|
||||
deterministic fake client without needing a base_url."""
|
||||
from clients.agent_backend.fake_client import FakeAgentBackendRunClient
|
||||
from configs import dify_config
|
||||
from core.workflow.nodes.agent_v2.session_cleanup_layer import (
|
||||
build_workflow_agent_session_cleanup_layer,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(dify_config, "AGENT_BACKEND_BASE_URL", None, raising=False)
|
||||
monkeypatch.setattr(dify_config, "AGENT_BACKEND_USE_FAKE", True, raising=False)
|
||||
monkeypatch.setattr(dify_config, "AGENT_BACKEND_FAKE_SCENARIO", "success", raising=False)
|
||||
|
||||
layer = build_workflow_agent_session_cleanup_layer()
|
||||
assert isinstance(layer._agent_backend_client, FakeAgentBackendRunClient) # type: ignore[reportPrivateUsage]
|
||||
@ -0,0 +1,286 @@
|
||||
"""Unit tests for :mod:`core.workflow.nodes.agent_v2.session_store`.
|
||||
|
||||
Uses the in-memory SQLite engine configured by the project conftest plus a
|
||||
per-test ``CREATE TABLE`` so the real ORM round-trip exercises every store
|
||||
method. Keeps the suite self-contained — no Postgres / Docker required — while
|
||||
still hitting the actual ``session_factory`` code path that production uses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from agenton.compositor.schemas import LayerSessionSnapshot
|
||||
from agenton.layers.base import LifecycleState
|
||||
from sqlalchemy import delete
|
||||
|
||||
from clients.agent_backend.request_builder import CleanupLayerSpec
|
||||
from core.db.session_factory import session_factory
|
||||
from core.workflow.nodes.agent_v2.session_store import (
|
||||
StoredWorkflowAgentSession,
|
||||
WorkflowAgentRuntimeSessionStore,
|
||||
WorkflowAgentSessionScope,
|
||||
)
|
||||
from models.agent import WorkflowAgentRuntimeSession, WorkflowAgentRuntimeSessionStatus
|
||||
|
||||
|
||||
def _scope(workflow_run_id: str | None = "wfr-1", binding_id: str = "binding-1") -> WorkflowAgentSessionScope:
|
||||
return WorkflowAgentSessionScope(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
workflow_id="workflow-1",
|
||||
workflow_run_id=workflow_run_id,
|
||||
node_id="agent-node",
|
||||
node_execution_id="node-exec-1",
|
||||
binding_id=binding_id,
|
||||
agent_id="agent-1",
|
||||
agent_config_snapshot_id="snapshot-1",
|
||||
)
|
||||
|
||||
|
||||
def _snapshot(messages: int = 1) -> CompositorSessionSnapshot:
|
||||
return CompositorSessionSnapshot(
|
||||
layers=[
|
||||
LayerSessionSnapshot(
|
||||
name="history",
|
||||
lifecycle_state=LifecycleState.SUSPENDED,
|
||||
runtime_state={"messages": [{"role": "user", "content": f"m{i}"} for i in range(messages)]},
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _specs() -> list[CleanupLayerSpec]:
|
||||
return [
|
||||
CleanupLayerSpec(name="workflow_node_job_prompt", type="plain.prompt", config={"prefix": "ok"}),
|
||||
CleanupLayerSpec(name="history", type="pydantic_ai.history"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _create_table() -> Generator[None, None, None]:
|
||||
"""Create the lifecycle table on the in-memory SQLite engine, drop after."""
|
||||
engine = session_factory.get_session_maker().kw["bind"]
|
||||
WorkflowAgentRuntimeSession.__table__.create(bind=engine, checkfirst=True)
|
||||
yield
|
||||
with session_factory.create_session() as session:
|
||||
session.execute(delete(WorkflowAgentRuntimeSession))
|
||||
session.commit()
|
||||
WorkflowAgentRuntimeSession.__table__.drop(bind=engine, checkfirst=True)
|
||||
|
||||
|
||||
def test_load_active_snapshot_returns_none_when_scope_has_no_workflow_run_id():
|
||||
"""``workflow_run_id`` is the keying column; no row can match without it."""
|
||||
store = WorkflowAgentRuntimeSessionStore()
|
||||
assert store.load_active_snapshot(_scope(workflow_run_id=None)) is None
|
||||
|
||||
|
||||
def test_load_active_snapshot_returns_none_when_no_row_matches():
|
||||
store = WorkflowAgentRuntimeSessionStore()
|
||||
assert store.load_active_snapshot(_scope()) is None
|
||||
|
||||
|
||||
def test_save_active_snapshot_creates_row_and_load_round_trips():
|
||||
store = WorkflowAgentRuntimeSessionStore()
|
||||
snapshot = _snapshot(messages=2)
|
||||
store.save_active_snapshot(
|
||||
scope=_scope(), backend_run_id="run-1", snapshot=snapshot, composition_layer_specs=_specs()
|
||||
)
|
||||
|
||||
loaded = store.load_active_snapshot(_scope())
|
||||
assert loaded is not None
|
||||
assert len(loaded.layers) == 1
|
||||
assert loaded.layers[0].name == "history"
|
||||
assert loaded.layers[0].runtime_state["messages"] == snapshot.layers[0].runtime_state["messages"]
|
||||
|
||||
|
||||
def test_save_active_snapshot_skips_when_workflow_run_id_missing():
|
||||
"""Without a workflow_run_id the row cannot be keyed; save is a no-op."""
|
||||
store = WorkflowAgentRuntimeSessionStore()
|
||||
store.save_active_snapshot(
|
||||
scope=_scope(workflow_run_id=None),
|
||||
backend_run_id="run-skipped",
|
||||
snapshot=_snapshot(),
|
||||
composition_layer_specs=_specs(),
|
||||
)
|
||||
with session_factory.create_session() as session:
|
||||
assert session.query(WorkflowAgentRuntimeSession).count() == 0
|
||||
|
||||
|
||||
def test_save_active_snapshot_skips_when_snapshot_missing():
|
||||
"""A run that produced no snapshot (e.g. failed agent run) does not write."""
|
||||
store = WorkflowAgentRuntimeSessionStore()
|
||||
store.save_active_snapshot(
|
||||
scope=_scope(),
|
||||
backend_run_id="run-empty",
|
||||
snapshot=None,
|
||||
composition_layer_specs=_specs(),
|
||||
)
|
||||
with session_factory.create_session() as session:
|
||||
assert session.query(WorkflowAgentRuntimeSession).count() == 0
|
||||
|
||||
|
||||
def test_save_active_snapshot_updates_existing_row_on_re_entry():
|
||||
"""A second save under the same scope must update in place, not insert."""
|
||||
store = WorkflowAgentRuntimeSessionStore()
|
||||
store.save_active_snapshot(
|
||||
scope=_scope(),
|
||||
backend_run_id="run-1",
|
||||
snapshot=_snapshot(messages=1),
|
||||
composition_layer_specs=_specs(),
|
||||
)
|
||||
# Second call with new snapshot + backend_run_id.
|
||||
store.save_active_snapshot(
|
||||
scope=_scope(),
|
||||
backend_run_id="run-2",
|
||||
snapshot=_snapshot(messages=2),
|
||||
composition_layer_specs=_specs(),
|
||||
)
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
rows = session.query(WorkflowAgentRuntimeSession).all()
|
||||
assert len(rows) == 1
|
||||
assert rows[0].backend_run_id == "run-2"
|
||||
assert rows[0].status == WorkflowAgentRuntimeSessionStatus.ACTIVE
|
||||
assert rows[0].cleaned_at is None
|
||||
|
||||
|
||||
def test_save_active_snapshot_resurrects_cleaned_row():
|
||||
"""If a prior cleanup retired the row, a re-entry flips it back to ACTIVE."""
|
||||
store = WorkflowAgentRuntimeSessionStore()
|
||||
store.save_active_snapshot(
|
||||
scope=_scope(),
|
||||
backend_run_id="run-1",
|
||||
snapshot=_snapshot(),
|
||||
composition_layer_specs=_specs(),
|
||||
)
|
||||
store.mark_cleaned(scope=_scope(), backend_run_id="cleanup-1")
|
||||
# Save again — the existing row was CLEANED; should be revived.
|
||||
store.save_active_snapshot(
|
||||
scope=_scope(),
|
||||
backend_run_id="run-2",
|
||||
snapshot=_snapshot(messages=3),
|
||||
composition_layer_specs=_specs(),
|
||||
)
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
rows = session.query(WorkflowAgentRuntimeSession).all()
|
||||
assert len(rows) == 1
|
||||
assert rows[0].status == WorkflowAgentRuntimeSessionStatus.ACTIVE
|
||||
assert rows[0].cleaned_at is None
|
||||
assert rows[0].backend_run_id == "run-2"
|
||||
|
||||
|
||||
def test_list_active_sessions_returns_specs_and_snapshot():
|
||||
store = WorkflowAgentRuntimeSessionStore()
|
||||
store.save_active_snapshot(
|
||||
scope=_scope(binding_id="binding-A"),
|
||||
backend_run_id="run-A",
|
||||
snapshot=_snapshot(),
|
||||
composition_layer_specs=_specs(),
|
||||
)
|
||||
store.save_active_snapshot(
|
||||
scope=_scope(binding_id="binding-B"),
|
||||
backend_run_id="run-B",
|
||||
snapshot=_snapshot(messages=2),
|
||||
composition_layer_specs=_specs(),
|
||||
)
|
||||
|
||||
listed = store.list_active_sessions(workflow_run_id="wfr-1")
|
||||
assert {s.backend_run_id for s in listed} == {"run-A", "run-B"}
|
||||
by_run = {s.backend_run_id: s for s in listed}
|
||||
assert isinstance(by_run["run-A"], StoredWorkflowAgentSession)
|
||||
# Specs round-trip through pydantic TypeAdapter — ensure deserialize works.
|
||||
assert by_run["run-A"].composition_layer_specs[0].name == "workflow_node_job_prompt"
|
||||
assert by_run["run-A"].composition_layer_specs[1].type == "pydantic_ai.history"
|
||||
# node_execution_id default-replaces NULL with "" when the DB column is None.
|
||||
assert by_run["run-A"].scope.node_execution_id == "node-exec-1"
|
||||
|
||||
|
||||
def test_list_active_sessions_skips_cleaned_rows():
|
||||
store = WorkflowAgentRuntimeSessionStore()
|
||||
store.save_active_snapshot(
|
||||
scope=_scope(binding_id="binding-A"),
|
||||
backend_run_id="run-A",
|
||||
snapshot=_snapshot(),
|
||||
composition_layer_specs=_specs(),
|
||||
)
|
||||
store.save_active_snapshot(
|
||||
scope=_scope(binding_id="binding-B"),
|
||||
backend_run_id="run-B",
|
||||
snapshot=_snapshot(),
|
||||
composition_layer_specs=_specs(),
|
||||
)
|
||||
store.mark_cleaned(scope=_scope(binding_id="binding-A"), backend_run_id="cleanup-A")
|
||||
|
||||
listed = store.list_active_sessions(workflow_run_id="wfr-1")
|
||||
assert {s.backend_run_id for s in listed} == {"run-B"}
|
||||
|
||||
|
||||
def test_list_active_sessions_handles_legacy_rows_without_specs():
|
||||
"""Rows persisted before composition_layer_specs landed have an empty string."""
|
||||
# Insert a legacy-shape row directly: empty specs payload simulates a row
|
||||
# written before the spec persistence feature landed in A.1.
|
||||
store = WorkflowAgentRuntimeSessionStore()
|
||||
store.save_active_snapshot(
|
||||
scope=_scope(),
|
||||
backend_run_id="run-legacy",
|
||||
snapshot=_snapshot(),
|
||||
composition_layer_specs=[],
|
||||
)
|
||||
listed = store.list_active_sessions(workflow_run_id="wfr-1")
|
||||
assert len(listed) == 1
|
||||
assert listed[0].composition_layer_specs == []
|
||||
|
||||
|
||||
def test_mark_cleaned_sets_status_and_cleaned_at_with_backend_run_id():
|
||||
store = WorkflowAgentRuntimeSessionStore()
|
||||
store.save_active_snapshot(
|
||||
scope=_scope(),
|
||||
backend_run_id="run-1",
|
||||
snapshot=_snapshot(),
|
||||
composition_layer_specs=_specs(),
|
||||
)
|
||||
store.mark_cleaned(scope=_scope(), backend_run_id="cleanup-1")
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
row = session.query(WorkflowAgentRuntimeSession).one()
|
||||
assert row.status == WorkflowAgentRuntimeSessionStatus.CLEANED
|
||||
assert row.cleaned_at is not None
|
||||
assert row.backend_run_id == "cleanup-1"
|
||||
|
||||
|
||||
def test_mark_cleaned_preserves_existing_backend_run_id_when_none_given():
|
||||
"""``backend_run_id=None`` means "leave the previous one in place"."""
|
||||
store = WorkflowAgentRuntimeSessionStore()
|
||||
store.save_active_snapshot(
|
||||
scope=_scope(),
|
||||
backend_run_id="run-1",
|
||||
snapshot=_snapshot(),
|
||||
composition_layer_specs=_specs(),
|
||||
)
|
||||
store.mark_cleaned(scope=_scope(), backend_run_id=None)
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
row = session.query(WorkflowAgentRuntimeSession).one()
|
||||
assert row.status == WorkflowAgentRuntimeSessionStatus.CLEANED
|
||||
assert row.backend_run_id == "run-1"
|
||||
|
||||
|
||||
def test_mark_cleaned_is_a_noop_when_no_active_row():
|
||||
"""No matching ACTIVE row → no-op (already-cleaned rows are not re-touched)."""
|
||||
store = WorkflowAgentRuntimeSessionStore()
|
||||
store.mark_cleaned(scope=_scope(), backend_run_id="cleanup-1")
|
||||
with session_factory.create_session() as session:
|
||||
assert session.query(WorkflowAgentRuntimeSession).count() == 0
|
||||
|
||||
|
||||
def test_mark_cleaned_is_a_noop_when_workflow_run_id_missing():
|
||||
"""Without a workflow_run_id we cannot key the row; ignore the call."""
|
||||
store = WorkflowAgentRuntimeSessionStore()
|
||||
store.mark_cleaned(scope=_scope(workflow_run_id=None), backend_run_id="cleanup-1")
|
||||
# Sanity — no rows created or touched.
|
||||
with session_factory.create_session() as session:
|
||||
assert session.query(WorkflowAgentRuntimeSession).count() == 0
|
||||
@ -11,6 +11,7 @@ from libs.oauth_bearer import (
|
||||
SubjectType,
|
||||
TokenKind,
|
||||
TokenKindRegistry,
|
||||
TokenType,
|
||||
)
|
||||
|
||||
|
||||
@ -21,7 +22,7 @@ def _registry_with_resolver(resolver) -> TokenKindRegistry:
|
||||
prefix="dfoa_",
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
source="oauth_account",
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
resolver=resolver,
|
||||
)
|
||||
]
|
||||
@ -63,7 +64,7 @@ def test_unknown_prefix_raises_generic_invalid_bearer():
|
||||
prefix="dfoa_",
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
source="oauth_account",
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
resolver=MagicMock(),
|
||||
)
|
||||
]
|
||||
|
||||
@ -19,6 +19,7 @@ from libs.oauth_bearer import (
|
||||
AuthContext,
|
||||
Scope,
|
||||
SubjectType,
|
||||
TokenType,
|
||||
require_scope,
|
||||
reset_auth_ctx,
|
||||
set_auth_ctx,
|
||||
@ -50,7 +51,7 @@ def _ctx(scopes) -> AuthContext:
|
||||
client_id="difyctl",
|
||||
scopes=scopes,
|
||||
token_id=uuid.uuid4(),
|
||||
source="oauth_account",
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
expires_at=None,
|
||||
token_hash="h1",
|
||||
verified_tenants={},
|
||||
|
||||
@ -8,7 +8,7 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType, require_workspace_member
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, require_workspace_member
|
||||
|
||||
|
||||
def _ctx(verified: dict[str, bool] | None = None, *, account: bool = True) -> AuthContext:
|
||||
@ -20,7 +20,7 @@ def _ctx(verified: dict[str, bool] | None = None, *, account: bool = True) -> Au
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
token_id=uuid.uuid4(),
|
||||
source="oauth_account",
|
||||
token_type=TokenType.OAUTH_ACCOUNT if account else TokenType.OAUTH_EXTERNAL_SSO,
|
||||
expires_at=None,
|
||||
token_hash="h1",
|
||||
verified_tenants=dict(verified or {}),
|
||||
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT, AuthContext, SubjectType
|
||||
from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT, AuthContext, SubjectType, TokenType
|
||||
from services.oauth_device_flow import (
|
||||
list_active_sessions,
|
||||
revoke_oauth_token,
|
||||
@ -21,7 +21,7 @@ def _account_ctx() -> AuthContext:
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({"full"}),
|
||||
token_id=uuid.uuid4(),
|
||||
source="oauth_account",
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
expires_at=None,
|
||||
token_hash="h1",
|
||||
verified_tenants={},
|
||||
@ -37,7 +37,7 @@ def _sso_ctx() -> AuthContext:
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({"apps:run"}),
|
||||
token_id=uuid.uuid4(),
|
||||
source="oauth_external_sso",
|
||||
token_type=TokenType.OAUTH_EXTERNAL_SSO,
|
||||
expires_at=None,
|
||||
token_hash="h1",
|
||||
verified_tenants={},
|
||||
|
||||
@ -1,115 +0,0 @@
|
||||
"""
|
||||
Unit tests for delete_account_task.
|
||||
|
||||
Covers:
|
||||
- Billing enabled with existing account: calls billing and sends success email
|
||||
- Billing disabled with existing account: skips billing, sends success email
|
||||
- Account not found: still calls billing when enabled, does not send email
|
||||
- Billing deletion raises: logs and re-raises, no email
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tasks.delete_account_task import delete_account_task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock session via session_factory.create_session()."""
|
||||
with patch("tasks.delete_account_task.session_factory") as mock_sf:
|
||||
session = MagicMock()
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
cm.__exit__.return_value = None
|
||||
mock_sf.create_session.return_value = cm
|
||||
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_deps():
|
||||
"""Patch external dependencies: BillingService and send_deletion_success_task."""
|
||||
with (
|
||||
patch("tasks.delete_account_task.BillingService") as mock_billing,
|
||||
patch("tasks.delete_account_task.send_deletion_success_task") as mock_mail_task,
|
||||
):
|
||||
# ensure .delay exists on the mail task
|
||||
mock_mail_task.delay = MagicMock()
|
||||
yield {
|
||||
"billing": mock_billing,
|
||||
"mail_task": mock_mail_task,
|
||||
}
|
||||
|
||||
|
||||
def _set_account_found(mock_db_session, email: str = "user@example.com"):
|
||||
account = SimpleNamespace(email=email)
|
||||
mock_db_session.scalar.return_value = account
|
||||
return account
|
||||
|
||||
|
||||
def _set_account_missing(mock_db_session):
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
|
||||
class TestDeleteAccountTask:
|
||||
def test_billing_enabled_account_exists_calls_billing_and_sends_email(self, mock_db_session, mock_deps):
|
||||
# Arrange
|
||||
account_id = "acc-123"
|
||||
account = _set_account_found(mock_db_session, email="a@b.com")
|
||||
|
||||
# Enable billing
|
||||
with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True):
|
||||
# Act
|
||||
delete_account_task(account_id)
|
||||
|
||||
# Assert
|
||||
mock_deps["billing"].delete_account.assert_called_once_with(account_id)
|
||||
mock_deps["mail_task"].delay.assert_called_once_with(account.email)
|
||||
|
||||
def test_billing_disabled_account_exists_sends_email_only(self, mock_db_session, mock_deps):
|
||||
# Arrange
|
||||
account_id = "acc-456"
|
||||
account = _set_account_found(mock_db_session, email="x@y.com")
|
||||
|
||||
# Disable billing
|
||||
with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", False):
|
||||
# Act
|
||||
delete_account_task(account_id)
|
||||
|
||||
# Assert
|
||||
mock_deps["billing"].delete_account.assert_not_called()
|
||||
mock_deps["mail_task"].delay.assert_called_once_with(account.email)
|
||||
|
||||
def test_account_not_found_billing_enabled_calls_billing_no_email(self, mock_db_session, mock_deps, caplog):
|
||||
# Arrange
|
||||
account_id = "missing-id"
|
||||
_set_account_missing(mock_db_session)
|
||||
|
||||
# Enable billing
|
||||
with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True):
|
||||
# Act
|
||||
delete_account_task(account_id)
|
||||
|
||||
# Assert
|
||||
mock_deps["billing"].delete_account.assert_called_once_with(account_id)
|
||||
mock_deps["mail_task"].delay.assert_not_called()
|
||||
# Optional: verify log contains not found message
|
||||
assert any("not found" in rec.getMessage().lower() for rec in caplog.records)
|
||||
|
||||
def test_billing_delete_raises_propagates_and_no_email(self, mock_db_session, mock_deps):
|
||||
# Arrange
|
||||
account_id = "acc-err"
|
||||
_set_account_found(mock_db_session, email="err@ex.com")
|
||||
mock_deps["billing"].delete_account.side_effect = RuntimeError("billing down")
|
||||
|
||||
# Enable billing
|
||||
with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True):
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError):
|
||||
delete_account_task(account_id)
|
||||
|
||||
# Ensure email was not sent
|
||||
mock_deps["mail_task"].delay.assert_not_called()
|
||||
@ -0,0 +1,19 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { redirect } from '@/next/navigation'
|
||||
import InstalledApp from '../page'
|
||||
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
redirect: vi.fn((path: string) => {
|
||||
throw new Error(`redirect:${path}`)
|
||||
}),
|
||||
}))
|
||||
|
||||
describe('legacy installed app route', () => {
|
||||
it('redirects to the canonical installed app route', async () => {
|
||||
await expect(InstalledApp({
|
||||
params: Promise.resolve({ appId: 'installed-1' }),
|
||||
})).rejects.toThrow('redirect:/installed/installed-1')
|
||||
|
||||
expect(redirect).toHaveBeenCalledWith('/installed/installed-1')
|
||||
})
|
||||
})
|
||||
@ -1,5 +1,5 @@
|
||||
import * as React from 'react'
|
||||
import Main from '@/app/components/explore/installed-app'
|
||||
import { buildInstalledAppPath } from '@/app/components/explore/installed-app/routes'
|
||||
import { redirect } from '@/next/navigation'
|
||||
|
||||
export type IInstalledAppProps = {
|
||||
params?: Promise<{
|
||||
@ -10,9 +10,7 @@ export type IInstalledAppProps = {
|
||||
// Using Next.js page convention for async server components
|
||||
async function InstalledApp({ params }: IInstalledAppProps) {
|
||||
const { appId } = await (params ?? Promise.reject(new Error('Missing params')))
|
||||
return (
|
||||
<Main id={appId} />
|
||||
)
|
||||
redirect(buildInstalledAppPath(appId))
|
||||
}
|
||||
|
||||
export default InstalledApp
|
||||
|
||||
@ -0,0 +1,23 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import InstalledApp from '../page'
|
||||
|
||||
vi.mock('@/app/components/explore/installed-app', () => ({
|
||||
default: ({ id }: { id: string }) => (
|
||||
<div data-testid="installed-app-page">
|
||||
{id}
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
describe('installed app route', () => {
|
||||
it('should render the installed app page with the route app id', async () => {
|
||||
const page = await InstalledApp({
|
||||
params: Promise.resolve({ appId: 'installed-1' }),
|
||||
})
|
||||
|
||||
render(page)
|
||||
|
||||
expect(screen.getByTestId('installed-app-page')).toHaveTextContent('installed-1')
|
||||
})
|
||||
})
|
||||
18
web/app/(commonLayout)/installed/[appId]/page.tsx
Normal file
18
web/app/(commonLayout)/installed/[appId]/page.tsx
Normal file
@ -0,0 +1,18 @@
|
||||
import * as React from 'react'
|
||||
import Main from '@/app/components/explore/installed-app'
|
||||
|
||||
export type IInstalledAppProps = {
|
||||
params?: Promise<{
|
||||
appId: string
|
||||
}>
|
||||
}
|
||||
|
||||
// Using Next.js page convention for async server components
|
||||
async function InstalledApp({ params }: IInstalledAppProps) {
|
||||
const { appId } = await (params ?? Promise.reject(new Error('Missing params')))
|
||||
return (
|
||||
<Main id={appId} />
|
||||
)
|
||||
}
|
||||
|
||||
export default InstalledApp
|
||||
@ -2,7 +2,7 @@ import Loading from '@/app/components/base/loading'
|
||||
|
||||
export default function CommonLayoutLoading() {
|
||||
return (
|
||||
<div className="flex h-screen w-screen items-center justify-center bg-background-body">
|
||||
<div className="flex min-h-0 w-full flex-1 items-center justify-center bg-background-body">
|
||||
<Loading />
|
||||
</div>
|
||||
)
|
||||
|
||||
12
web/app/(commonLayout)/marketplace/layout.tsx
Normal file
12
web/app/(commonLayout)/marketplace/layout.tsx
Normal file
@ -0,0 +1,12 @@
|
||||
'use client'
|
||||
|
||||
import type { PropsWithChildren } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import useDocumentTitle from '@/hooks/use-document-title'
|
||||
|
||||
export default function MarketplaceLayout({ children }: PropsWithChildren) {
|
||||
const { t } = useTranslation()
|
||||
useDocumentTitle(t('mainNav.marketplace', { ns: 'common' }))
|
||||
|
||||
return children
|
||||
}
|
||||
@ -60,7 +60,6 @@ describe('RoleRouteGuard', () => {
|
||||
it.each([
|
||||
'/',
|
||||
'/apps',
|
||||
'/roster',
|
||||
'/tools',
|
||||
'/integrations/model-provider',
|
||||
])('should redirect dataset operator on guarded route %s', async (pathname) => {
|
||||
|
||||
@ -6,7 +6,7 @@ import Loading from '@/app/components/base/loading'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { usePathname, useRouter } from '@/next/navigation'
|
||||
|
||||
const datasetOperatorRedirectRoutes = ['/', '/apps', '/app', '/roster', '/explore', '/tools', '/integrations'] as const
|
||||
const datasetOperatorRedirectRoutes = ['/', '/apps', '/app', '/explore', '/tools', '/integrations'] as const
|
||||
|
||||
const isPathUnderRoute = (pathname: string, route: string) => pathname === route || pathname.startsWith(`${route}/`)
|
||||
|
||||
|
||||
@ -1,5 +0,0 @@
|
||||
import { AgentDetailPage } from '@/features/agent-v2/pages/agent-detail-page'
|
||||
|
||||
export default function Page() {
|
||||
return <AgentDetailPage section="access" />
|
||||
}
|
||||
@ -1,5 +0,0 @@
|
||||
import { AgentDetailPage } from '@/features/agent-v2/pages/agent-detail-page'
|
||||
|
||||
export default function Page() {
|
||||
return <AgentDetailPage section="configure" />
|
||||
}
|
||||
@ -1,20 +0,0 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import { AgentDetailLayout } from '@/features/agent-v2/layouts/agent-detail-layout'
|
||||
|
||||
type LayoutProps = {
|
||||
children: ReactNode
|
||||
params: Promise<{ agentId: string }>
|
||||
}
|
||||
|
||||
export default async function Layout({
|
||||
children,
|
||||
params,
|
||||
}: LayoutProps) {
|
||||
const { agentId } = await params
|
||||
|
||||
return (
|
||||
<AgentDetailLayout agentId={agentId}>
|
||||
{children}
|
||||
</AgentDetailLayout>
|
||||
)
|
||||
}
|
||||
@ -1,5 +0,0 @@
|
||||
import { AgentDetailPage } from '@/features/agent-v2/pages/agent-detail-page'
|
||||
|
||||
export default function Page() {
|
||||
return <AgentDetailPage section="logs" />
|
||||
}
|
||||
@ -1,5 +0,0 @@
|
||||
import { AgentDetailPage } from '@/features/agent-v2/pages/agent-detail-page'
|
||||
|
||||
export default function Page() {
|
||||
return <AgentDetailPage section="monitoring" />
|
||||
}
|
||||
@ -1,13 +0,0 @@
|
||||
import { redirect } from '@/next/navigation'
|
||||
|
||||
type PageProps = {
|
||||
params: Promise<{ agentId: string }>
|
||||
}
|
||||
|
||||
export default async function Page({
|
||||
params,
|
||||
}: PageProps) {
|
||||
const { agentId } = await params
|
||||
|
||||
redirect(`/roster/${agentId}/configure`)
|
||||
}
|
||||
@ -1,5 +0,0 @@
|
||||
import RosterPage from '@/features/agent-v2/pages/roster-page'
|
||||
|
||||
export default function Page() {
|
||||
return <RosterPage />
|
||||
}
|
||||
@ -21,6 +21,12 @@ describe('AppDetailTop', () => {
|
||||
expect(screen.getByRole('link', { name: 'common.mainNav.home' })).toHaveAttribute('href', '/')
|
||||
})
|
||||
|
||||
it('links the Studio breadcrumb to the Studio page', () => {
|
||||
render(<AppDetailTop />)
|
||||
|
||||
expect(screen.getByRole('link', { name: 'common.menus.apps' })).toHaveAttribute('href', '/apps')
|
||||
})
|
||||
|
||||
it('keeps the back button and quick search actions', () => {
|
||||
const handleOpen = vi.fn()
|
||||
window.addEventListener(GOTO_ANYTHING_OPEN_EVENT, handleOpen)
|
||||
|
||||
@ -32,9 +32,12 @@ const AppDetailTop = () => {
|
||||
<span className="mx-1.5 shrink-0 system-md-regular text-text-quaternary">
|
||||
/
|
||||
</span>
|
||||
<span className="shrink-0 truncate system-sm-semibold-uppercase text-text-secondary">
|
||||
<Link
|
||||
href="/apps"
|
||||
className="shrink-0 truncate system-sm-semibold-uppercase text-text-secondary hover:text-text-primary"
|
||||
>
|
||||
{t('menus.apps', { ns: 'common' })}
|
||||
</span>
|
||||
</Link>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
|
||||
@ -204,7 +204,7 @@ describe('AppPublisher', () => {
|
||||
access_mode: AccessMode.PUBLIC,
|
||||
})
|
||||
mockOpenAsyncWindow.mockImplementation(async (resolver: () => Promise<string>) => {
|
||||
await resolver()
|
||||
return resolver()
|
||||
})
|
||||
Object.defineProperty(window, 'open', {
|
||||
writable: true,
|
||||
@ -397,6 +397,11 @@ describe('AppPublisher', () => {
|
||||
})
|
||||
|
||||
it('should open the installed explore page through the async window helper', async () => {
|
||||
let openedUrl = ''
|
||||
mockOpenAsyncWindow.mockImplementation(async (resolver: () => Promise<string>) => {
|
||||
openedUrl = await resolver()
|
||||
})
|
||||
|
||||
render(
|
||||
<AppPublisher
|
||||
publishedAt={Date.now()}
|
||||
@ -409,6 +414,7 @@ describe('AppPublisher', () => {
|
||||
await waitFor(() => {
|
||||
expect(mockOpenAsyncWindow).toHaveBeenCalledTimes(1)
|
||||
expect(mockFetchInstalledAppList).toHaveBeenCalledWith('app-1')
|
||||
expect(openedUrl).toBe('/installed/installed-1')
|
||||
expect(sectionProps.actions?.appURL).toBe(`https://example.com${basePath}/chat/token-1`)
|
||||
})
|
||||
})
|
||||
|
||||
@ -29,6 +29,7 @@ import {
|
||||
import EmbeddedModal from '@/app/components/app/overview/embedded'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import { trackEvent } from '@/app/components/base/amplitude'
|
||||
import { buildInstalledAppPath } from '@/app/components/explore/installed-app/routes'
|
||||
import { WorkflowToolDrawer } from '@/app/components/tools/workflow-tool'
|
||||
import { useConfigureButton } from '@/app/components/tools/workflow-tool/hooks/use-configure-button'
|
||||
import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager'
|
||||
@ -238,7 +239,7 @@ const AppPublisher = ({
|
||||
throw new Error('App not found')
|
||||
const { installed_apps } = await fetchInstalledAppList(appDetail.id)
|
||||
if (installed_apps?.length > 0)
|
||||
return `${basePath}/explore/installed/${installed_apps[0]!.id}`
|
||||
return `${basePath}${buildInstalledAppPath(installed_apps[0]!.id)}`
|
||||
throw new Error('No app found in Explore')
|
||||
}, {
|
||||
onError: (err) => {
|
||||
|
||||
@ -1400,9 +1400,10 @@ describe('AppCard', () => {
|
||||
})
|
||||
|
||||
it('should handle open in explore via async window', async () => {
|
||||
let openedUrl = ''
|
||||
// Configure mockOpenAsyncWindow to actually call the callback
|
||||
mockOpenAsyncWindow.mockImplementationOnce(async (callback: () => Promise<string>) => {
|
||||
await callback()
|
||||
openedUrl = await callback()
|
||||
})
|
||||
|
||||
render(<AppCard app={mockApp} />)
|
||||
@ -1415,6 +1416,7 @@ describe('AppCard', () => {
|
||||
|
||||
await waitFor(() => {
|
||||
expect(exploreService.fetchInstalledAppList).toHaveBeenCalledWith(mockApp.id)
|
||||
expect(openedUrl).toBe('/installed/installed-1')
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@ -36,6 +36,7 @@ import { Trans, useTranslation } from 'react-i18next'
|
||||
import { AppTypeIcon } from '@/app/components/app/type-selector'
|
||||
import AppIcon from '@/app/components/base/app-icon'
|
||||
import { UserAvatarList } from '@/app/components/base/user-avatar-list'
|
||||
import { buildInstalledAppPath } from '@/app/components/explore/installed-app/routes'
|
||||
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
@ -165,7 +166,7 @@ const AppCardOperationsMenu: React.FC<AppCardOperationsMenuProps> = ({
|
||||
await openAsyncWindow(async () => {
|
||||
const { installed_apps } = await fetchInstalledAppList(app.id)
|
||||
if (installed_apps?.length > 0)
|
||||
return `${basePath}/explore/installed/${installed_apps[0]!.id}`
|
||||
return `${basePath}${buildInstalledAppPath(installed_apps[0]!.id)}`
|
||||
throw new Error('No app found in Explore')
|
||||
}, {
|
||||
onError: (err) => {
|
||||
|
||||
90
web/app/components/apps/app-list-creation-modals.tsx
Normal file
90
web/app/components/apps/app-list-creation-modals.tsx
Normal file
@ -0,0 +1,90 @@
|
||||
'use client'
|
||||
|
||||
import type { AppListCategory } from './hooks/use-apps-query-state'
|
||||
import dynamic from '@/next/dynamic'
|
||||
|
||||
const CreateFromDSLModal = dynamic(() => import('@/app/components/app/create-from-dsl-modal'), {
|
||||
ssr: false,
|
||||
})
|
||||
const CreateAppModal = dynamic(() => import('@/app/components/app/create-app-modal'), {
|
||||
ssr: false,
|
||||
})
|
||||
const CreateAppTemplateDialog = dynamic(() => import('@/app/components/app/create-app-dialog'), {
|
||||
ssr: false,
|
||||
})
|
||||
|
||||
export function AppListCreationModals({
|
||||
category,
|
||||
droppedDSLFile,
|
||||
showCreateFromDSLModal,
|
||||
showNewAppModal,
|
||||
showNewAppTemplateDialog,
|
||||
onPlanInfoChanged,
|
||||
onRefetch,
|
||||
onSetDroppedDSLFile,
|
||||
onSetShowCreateFromDSLModal,
|
||||
onSetShowNewAppModal,
|
||||
onSetShowNewAppTemplateDialog,
|
||||
}: {
|
||||
category: AppListCategory
|
||||
droppedDSLFile?: File
|
||||
showCreateFromDSLModal: boolean
|
||||
showNewAppModal: boolean
|
||||
showNewAppTemplateDialog: boolean
|
||||
onPlanInfoChanged: () => void
|
||||
onRefetch: () => void
|
||||
onSetDroppedDSLFile: (file?: File) => void
|
||||
onSetShowCreateFromDSLModal: (show: boolean) => void
|
||||
onSetShowNewAppModal: (show: boolean) => void
|
||||
onSetShowNewAppTemplateDialog: (show: boolean) => void
|
||||
}) {
|
||||
return (
|
||||
<>
|
||||
{showCreateFromDSLModal && (
|
||||
<CreateFromDSLModal
|
||||
show={showCreateFromDSLModal}
|
||||
onClose={() => {
|
||||
onSetShowCreateFromDSLModal(false)
|
||||
onSetDroppedDSLFile(undefined)
|
||||
}}
|
||||
onSuccess={() => {
|
||||
onSetShowCreateFromDSLModal(false)
|
||||
onSetDroppedDSLFile(undefined)
|
||||
onPlanInfoChanged()
|
||||
onRefetch()
|
||||
}}
|
||||
droppedFile={droppedDSLFile}
|
||||
/>
|
||||
)}
|
||||
{showNewAppModal && (
|
||||
<CreateAppModal
|
||||
show={showNewAppModal}
|
||||
onClose={() => onSetShowNewAppModal(false)}
|
||||
onSuccess={() => {
|
||||
onPlanInfoChanged()
|
||||
onRefetch()
|
||||
}}
|
||||
onCreateFromTemplate={() => {
|
||||
onSetShowNewAppTemplateDialog(true)
|
||||
onSetShowNewAppModal(false)
|
||||
}}
|
||||
defaultAppMode={category !== 'all' ? category : undefined}
|
||||
/>
|
||||
)}
|
||||
{showNewAppTemplateDialog && (
|
||||
<CreateAppTemplateDialog
|
||||
show={showNewAppTemplateDialog}
|
||||
onClose={() => onSetShowNewAppTemplateDialog(false)}
|
||||
onSuccess={() => {
|
||||
onPlanInfoChanged()
|
||||
onRefetch()
|
||||
}}
|
||||
onCreateFromBlank={() => {
|
||||
onSetShowNewAppModal(true)
|
||||
onSetShowNewAppTemplateDialog(false)
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
26
web/app/components/apps/app-list-tag-management-modal.tsx
Normal file
26
web/app/components/apps/app-list-tag-management-modal.tsx
Normal file
@ -0,0 +1,26 @@
|
||||
'use client'
|
||||
|
||||
import dynamic from '@/next/dynamic'
|
||||
|
||||
const TagManagementModal = dynamic(() => import('@/features/tag-management/components/tag-management-modal').then(mod => mod.TagManagementModal), {
|
||||
ssr: false,
|
||||
})
|
||||
|
||||
export function AppListTagManagementModal({
|
||||
show,
|
||||
onClose,
|
||||
onTagsChange,
|
||||
}: {
|
||||
show: boolean
|
||||
onClose: () => void
|
||||
onTagsChange: () => unknown
|
||||
}) {
|
||||
return (
|
||||
<TagManagementModal
|
||||
type="app"
|
||||
show={show}
|
||||
onClose={onClose}
|
||||
onTagsChange={onTagsChange}
|
||||
/>
|
||||
)
|
||||
}
|
||||
@ -1,6 +1,5 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import type { AppListQuery } from '@/contract/console/apps'
|
||||
import { cn } from '@langgenius/dify-ui/cn'
|
||||
import { keepPreviousData, useInfiniteQuery, useSuspenseQuery } from '@tanstack/react-query'
|
||||
@ -11,13 +10,14 @@ import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { CheckModal } from '@/hooks/use-pay'
|
||||
import dynamic from '@/next/dynamic'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import { systemFeaturesQueryOptions } from '@/service/system-features'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import AppCard from './app-card'
|
||||
import { AppCardSkeleton } from './app-card-skeleton'
|
||||
import { AppListCreationModals } from './app-list-creation-modals'
|
||||
import AppListHeaderFilters from './app-list-header-filters'
|
||||
import { AppListTagManagementModal } from './app-list-tag-management-modal'
|
||||
import { APP_LIST_SEARCH_DEBOUNCE_MS } from './constants'
|
||||
import Empty from './empty'
|
||||
import FirstEmptyState from './first-empty-state'
|
||||
@ -27,25 +27,7 @@ import { useDSLDragDrop } from './hooks/use-dsl-drag-drop'
|
||||
import { useWorkflowOnlineUsers } from './hooks/use-workflow-online-users'
|
||||
import NewAppCard from './new-app-card'
|
||||
|
||||
const TagManagementModal = dynamic(() => import('@/features/tag-management/components/tag-management-modal').then(mod => mod.TagManagementModal), {
|
||||
ssr: false,
|
||||
})
|
||||
const CreateFromDSLModal = dynamic(() => import('@/app/components/app/create-from-dsl-modal'), {
|
||||
ssr: false,
|
||||
})
|
||||
const CreateAppModal = dynamic(() => import('@/app/components/app/create-app-modal'), {
|
||||
ssr: false,
|
||||
})
|
||||
const CreateAppTemplateDialog = dynamic(() => import('@/app/components/app/create-app-dialog'), {
|
||||
ssr: false,
|
||||
})
|
||||
|
||||
type Props = {
|
||||
controlRefreshList?: number
|
||||
}
|
||||
const List: FC<Props> = ({
|
||||
controlRefreshList = 0,
|
||||
}) => {
|
||||
function List({ controlRefreshList = 0 }: { controlRefreshList?: number }) {
|
||||
const { t } = useTranslation()
|
||||
const { data: systemFeatures } = useSuspenseQuery(systemFeaturesQueryOptions())
|
||||
const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, isLoadingCurrentWorkspace } = useAppContext()
|
||||
@ -141,9 +123,8 @@ const List: FC<Props> = ({
|
||||
}
|
||||
|
||||
if (anchorRef.current && containerRef.current) {
|
||||
// Calculate dynamic rootMargin: clamps to 100-200px range, using 20% of container height as the base value for better responsiveness
|
||||
const containerHeight = containerRef.current.clientHeight
|
||||
const dynamicMargin = Math.max(100, Math.min(containerHeight * 0.2, 200)) // Clamps to 100-200px range, using 20% of container height as the base value
|
||||
const dynamicMargin = Math.max(100, Math.min(containerHeight * 0.2, 200))
|
||||
|
||||
observer = new IntersectionObserver((entries) => {
|
||||
if (entries[0]!.isIntersecting && !isLoading && !isFetchingNextPage && !error && hasMore)
|
||||
@ -151,7 +132,7 @@ const List: FC<Props> = ({
|
||||
}, {
|
||||
root: containerRef.current,
|
||||
rootMargin: `${dynamicMargin}px`,
|
||||
threshold: 0.1, // Trigger when 10% of the anchor element is visible
|
||||
threshold: 0.1,
|
||||
})
|
||||
observer.observe(anchorRef.current)
|
||||
}
|
||||
@ -196,7 +177,6 @@ const List: FC<Props> = ({
|
||||
const hasResolvedFirstPage = pages.length > 0
|
||||
const hasAnyApp = (pages[0]?.total ?? 0) > 0
|
||||
const hasActiveFilters = category !== 'all' || tagIDs.length > 0 || keywords.trim().length > 0 || debouncedKeywords.trim().length > 0 || isCreatedByMe
|
||||
// Show skeleton during initial load or when refetching with no previous data
|
||||
const showSkeleton = !emptyAppList && (isLoading || (isFetching && pages.length === 0))
|
||||
const showFirstEmptyState = !showSkeleton && !hasAnyApp && isCurrentWorkspaceEditor && (emptyAppList || (hasResolvedFirstPage && !hasActiveFilters))
|
||||
|
||||
@ -290,59 +270,26 @@ const List: FC<Props> = ({
|
||||
)}
|
||||
<CheckModal />
|
||||
<div ref={anchorRef} className="h-0"> </div>
|
||||
<TagManagementModal
|
||||
type="app"
|
||||
<AppListTagManagementModal
|
||||
show={showTagManagementModal}
|
||||
onClose={() => setShowTagManagementModal(false)}
|
||||
onTagsChange={refetch}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{showCreateFromDSLModal && (
|
||||
<CreateFromDSLModal
|
||||
show={showCreateFromDSLModal}
|
||||
onClose={() => {
|
||||
setShowCreateFromDSLModal(false)
|
||||
setDroppedDSLFile(undefined)
|
||||
}}
|
||||
onSuccess={() => {
|
||||
setShowCreateFromDSLModal(false)
|
||||
setDroppedDSLFile(undefined)
|
||||
onPlanInfoChanged()
|
||||
refetch()
|
||||
}}
|
||||
droppedFile={droppedDSLFile}
|
||||
/>
|
||||
)}
|
||||
{showNewAppModal && (
|
||||
<CreateAppModal
|
||||
show={showNewAppModal}
|
||||
onClose={() => setShowNewAppModal(false)}
|
||||
onSuccess={() => {
|
||||
onPlanInfoChanged()
|
||||
refetch()
|
||||
}}
|
||||
onCreateFromTemplate={() => {
|
||||
setShowNewAppTemplateDialog(true)
|
||||
setShowNewAppModal(false)
|
||||
}}
|
||||
defaultAppMode={category !== 'all' ? category : undefined}
|
||||
/>
|
||||
)}
|
||||
{showNewAppTemplateDialog && (
|
||||
<CreateAppTemplateDialog
|
||||
show={showNewAppTemplateDialog}
|
||||
onClose={() => setShowNewAppTemplateDialog(false)}
|
||||
onSuccess={() => {
|
||||
onPlanInfoChanged()
|
||||
refetch()
|
||||
}}
|
||||
onCreateFromBlank={() => {
|
||||
setShowNewAppModal(true)
|
||||
setShowNewAppTemplateDialog(false)
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
<AppListCreationModals
|
||||
category={category}
|
||||
droppedDSLFile={droppedDSLFile}
|
||||
showCreateFromDSLModal={showCreateFromDSLModal}
|
||||
showNewAppModal={showNewAppModal}
|
||||
showNewAppTemplateDialog={showNewAppTemplateDialog}
|
||||
onPlanInfoChanged={onPlanInfoChanged}
|
||||
onRefetch={refetch}
|
||||
onSetDroppedDSLFile={setDroppedDSLFile}
|
||||
onSetShowCreateFromDSLModal={setShowCreateFromDSLModal}
|
||||
onSetShowNewAppModal={setShowNewAppModal}
|
||||
onSetShowNewAppTemplateDialog={setShowNewAppTemplateDialog}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
@ -104,9 +104,9 @@ describe('AudioBtn', () => {
|
||||
expect(call![1]).toBe(false)
|
||||
})
|
||||
|
||||
it('should call installed app endpoint for explore installed routes', async () => {
|
||||
it('should call installed app endpoint for installed app routes', async () => {
|
||||
mockUseParams({ appId: '456' })
|
||||
mockUsePathname('/explore/installed/app/456')
|
||||
mockUsePathname('/installed/456')
|
||||
|
||||
render(<AudioBtn value="test" />)
|
||||
await userEvent.click(getButton())
|
||||
|
||||
@ -4,6 +4,7 @@ import { t } from 'i18next'
|
||||
import { useState } from 'react'
|
||||
import { AudioPlayerManager } from '@/app/components/base/audio-btn/audio.player.manager'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { isInstalledAppPath } from '@/app/components/explore/installed-app/routes'
|
||||
import { useParams, usePathname } from '@/next/navigation'
|
||||
import s from './style.module.css'
|
||||
|
||||
@ -56,7 +57,7 @@ const AudioBtn = ({
|
||||
isPublic = true
|
||||
}
|
||||
else if (params.appId) {
|
||||
if (pathname.search('explore/installed') > -1)
|
||||
if (isInstalledAppPath(pathname))
|
||||
url = `/installed-apps/${params.appId}/text-to-audio`
|
||||
else
|
||||
url = `/apps/${params.appId}/text-to-audio`
|
||||
|
||||
@ -1266,7 +1266,7 @@ describe('useChat', () => {
|
||||
|
||||
describe('createAudioPlayerManager branch cases', () => {
|
||||
it('should handle ttsUrl generation for appId with installed apps', async () => {
|
||||
vi.mocked(usePathname).mockReturnValue('/explore/installed/app')
|
||||
vi.mocked(usePathname).mockReturnValue('/installed/app-1')
|
||||
vi.mocked(useParams).mockReturnValue({ appId: 'app-1' } as ReturnType<typeof useParams>)
|
||||
|
||||
let callbacks: HookCallbacks
|
||||
|
||||
@ -30,6 +30,7 @@ import {
|
||||
getProcessedFiles,
|
||||
getProcessedFilesFromResponse,
|
||||
} from '@/app/components/base/file-uploader/utils'
|
||||
import { isInstalledAppPath } from '@/app/components/explore/installed-app/routes'
|
||||
import { NodeRunningStatus, WorkflowRunningStatus } from '@/app/components/workflow/types'
|
||||
import useTimestamp from '@/hooks/use-timestamp'
|
||||
import { useParams, usePathname } from '@/next/navigation'
|
||||
@ -217,7 +218,7 @@ export const useChat = (
|
||||
ttsIsPublic = true
|
||||
}
|
||||
else if (params.appId) {
|
||||
if (pathname.search('explore/installed') > -1)
|
||||
if (isInstalledAppPath(pathname))
|
||||
ttsUrl = `/installed-apps/${params.appId}/text-to-audio`
|
||||
else
|
||||
ttsUrl = `/apps/${params.appId}/text-to-audio`
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user