Compare commits

..

1 Commits

Author SHA1 Message Date
45e8187182 fix(workflow): keep pointer position out of reactive store 2026-05-27 17:06:07 +08:00
331 changed files with 3621 additions and 17294 deletions

1
.github/CODEOWNERS vendored
View File

@ -166,7 +166,6 @@
# Frontend - App - API Documentation
/web/app/components/develop/ @JzoNgKVO @iamjoel
/web/app/components/develop/template/*.mdx @JzoNgKVO @iamjoel @RiskeyL
# Frontend - App - Logs and Annotations
/web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel

View File

@ -27,7 +27,7 @@ COPY api/providers ./providers
COPY dify-agent/pyproject.toml dify-agent/README.md /app/dify-agent/
COPY dify-agent/src /app/dify-agent/src
# Trust the checked-in lock during image builds; local path sources are copied from the repository context.
RUN uv sync --frozen --no-dev --no-editable
RUN uv sync --frozen --no-dev
# production stage
FROM base AS production

View File

@ -38,8 +38,6 @@ from clients.agent_backend.request_builder import (
AgentBackendOutputConfig,
AgentBackendRunRequestBuilder,
AgentBackendWorkflowNodeRunInput,
CleanupLayerSpec,
extract_cleanup_layer_specs,
redact_for_agent_backend_log,
)
@ -70,11 +68,9 @@ __all__ = [
"AgentBackendTransportError",
"AgentBackendValidationError",
"AgentBackendWorkflowNodeRunInput",
"CleanupLayerSpec",
"DifyAgentBackendRunClient",
"FakeAgentBackendRunClient",
"FakeAgentBackendScenario",
"create_agent_backend_run_client",
"extract_cleanup_layer_specs",
"redact_for_agent_backend_log",
]

View File

@ -20,8 +20,6 @@ from dify_agent.protocol import (
RunEvent,
RunFailedEvent,
RunFailedEventData,
RunPausedEvent,
RunPausedEventData,
RunStartedEvent,
RunStatusResponse,
RunSucceededEvent,
@ -36,7 +34,6 @@ class FakeAgentBackendScenario(StrEnum):
SUCCESS = "success"
FAILED = "failed"
PAUSED = "paused"
class FakeAgentBackendRunClient:
@ -92,13 +89,6 @@ class FakeAgentBackendRunClient:
updated_at=_FIXED_TIME,
error="fake failure",
)
case FakeAgentBackendScenario.PAUSED:
return RunStatusResponse(
run_id=run_id,
status="paused",
created_at=_FIXED_TIME,
updated_at=_FIXED_TIME,
)
def _events(self, run_id: str) -> tuple[RunEvent, ...]:
match self.scenario:
@ -125,17 +115,3 @@ class FakeAgentBackendRunClient:
data=RunFailedEventData(error="fake failure", reason="unit_test"),
),
)
case FakeAgentBackendScenario.PAUSED:
return (
RunStartedEvent(id="1-0", run_id=run_id, created_at=_FIXED_TIME),
RunPausedEvent(
id="2-0",
run_id=run_id,
created_at=_FIXED_TIME,
data=RunPausedEventData(
reason="human_input_required",
message="Agent requested human input.",
session_snapshot=CompositorSessionSnapshot(layers=[]),
),
),
)

View File

@ -11,13 +11,11 @@ composition-driven.
from __future__ import annotations
from typing import ClassVar, cast
from typing import ClassVar
from agenton.compositor import CompositorSessionSnapshot
from agenton.compositor.schemas import LayerSessionSnapshot
from agenton.layers import ExitIntent
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig
from agenton_collections.layers.pydantic_ai import PYDANTIC_AI_HISTORY_LAYER_TYPE_ID
from dify_agent.layers.dify_plugin import (
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
@ -31,7 +29,6 @@ from dify_agent.layers.execution_context import (
)
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
from dify_agent.protocol import (
DIFY_AGENT_HISTORY_LAYER_ID,
DIFY_AGENT_MODEL_LAYER_ID,
DIFY_AGENT_OUTPUT_LAYER_ID,
CreateRunRequest,
@ -48,84 +45,6 @@ WORKFLOW_USER_PROMPT_LAYER_ID = "workflow_user_prompt"
DIFY_EXECUTION_CONTEXT_LAYER_ID = "execution_context"
DIFY_PLUGIN_TOOLS_LAYER_ID = "tools"
# Layer types that hold credentials in their per-run config. These are excluded
# from the cleanup-replay composition (and from the snapshot that is sent with
# the cleanup request) because we deliberately do not persist plaintext
# credentials between runs.
_CLEANUP_EXCLUDED_LAYER_TYPES: tuple[str, ...] = (
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
)
class CleanupLayerSpec(BaseModel):
"""One layer node replayed by an Agent backend cleanup-only run.
Cleanup composition cannot include credential-bearing plugin layers, so we
persist only the non-plugin layer specs together with the original config.
Storing the config (rather than just ``name``/``type``) means cleanup does
not depend on the original build-time inputs being re-derivable.
"""
name: str
type: str
deps: dict[str, str] = Field(default_factory=dict)
metadata: dict[str, JsonValue] = Field(default_factory=dict)
config: JsonValue = None
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
def extract_cleanup_layer_specs(composition: RunComposition) -> list[CleanupLayerSpec]:
"""Project the in-flight composition into the persistable cleanup spec list.
Plugin layers are intentionally dropped (their configs hold credentials and
the lifecycle contract says "do not include an LLM layer" during cleanup).
The filtered names must later drive snapshot filtering so the agenton
compositor's name-order check still passes for the cleanup run.
"""
excluded = set(_CLEANUP_EXCLUDED_LAYER_TYPES)
specs: list[CleanupLayerSpec] = []
for layer in composition.layers:
if layer.type in excluded:
continue
config_value: JsonValue = None
if isinstance(layer.config, BaseModel):
config_value = layer.config.model_dump(mode="json", warnings=False)
else:
# ``RunLayerSpec.config`` is typed as ``LayerConfigInput`` which
# includes ``Mapping[str, object] | bytes``. In the cleanup-replay
# pipeline our builder only emits BaseModel-derived configs or
# ``None``, so the wider input alias narrows safely here.
config_value = cast(JsonValue, layer.config)
specs.append(
CleanupLayerSpec(
name=layer.name,
type=layer.type,
deps=dict(layer.deps),
metadata=dict(layer.metadata),
config=config_value,
)
)
return specs
def _filter_snapshot_to_specs(
snapshot: CompositorSessionSnapshot,
specs: list[CleanupLayerSpec],
) -> CompositorSessionSnapshot:
"""Keep only snapshot layers whose names appear in the cleanup spec list.
The agenton compositor rejects a snapshot whose layer-name sequence does
not match the active composition exactly. Cleanup-replay drops plugin
layers, so we must drop the matching snapshot entries here.
"""
kept_names = {spec.name for spec in specs}
filtered_layers: list[LayerSessionSnapshot] = [layer for layer in snapshot.layers if layer.name in kept_names]
if len(filtered_layers) == len(snapshot.layers):
return snapshot
return CompositorSessionSnapshot(schema_version=snapshot.schema_version, layers=filtered_layers)
class AgentBackendModelConfig(BaseModel):
"""API-side model/plugin selection before it is converted to Dify Agent layers."""
@ -167,8 +86,7 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
output: AgentBackendOutputConfig | None = None
tools: DifyPluginToolsLayerConfig | None = None
session_snapshot: CompositorSessionSnapshot | None = None
include_history: bool = True
suspend_on_exit: bool = True
suspend_on_exit: bool = False
metadata: dict[str, JsonValue] = Field(default_factory=dict)
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
@ -184,50 +102,6 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
class AgentBackendRunRequestBuilder:
"""Converts API product state into the public ``dify-agent`` run protocol."""
def build_cleanup_request(
self,
*,
session_snapshot: CompositorSessionSnapshot,
composition_layer_specs: list[CleanupLayerSpec],
idempotency_key: str | None = None,
metadata: dict[str, JsonValue] | None = None,
) -> CreateRunRequest:
"""Build a lifecycle-only cleanup request that replays the prior layers.
The agenton compositor enforces that the session snapshot's layer names
match the active composition in order, so cleanup must replay the same
non-plugin layer graph that produced the snapshot. Plugin layers
(``dify.plugin.llm``, ``dify.plugin.tools``) are excluded from both the
composition and the snapshot before submission because their configs
require credentials that are not persisted between runs.
"""
if not composition_layer_specs:
raise ValueError(
"build_cleanup_request requires composition_layer_specs; an empty "
"composition would fail the agent backend's snapshot validation."
)
request_metadata = dict(metadata or {})
request_metadata["agent_backend_lifecycle"] = "session_cleanup"
layers = [
RunLayerSpec(
name=spec.name,
type=spec.type,
deps=dict(spec.deps),
metadata=dict(spec.metadata),
config=spec.config,
)
for spec in composition_layer_specs
]
filtered_snapshot = _filter_snapshot_to_specs(session_snapshot, composition_layer_specs)
return CreateRunRequest(
composition=RunComposition(layers=layers),
purpose="workflow_node",
idempotency_key=idempotency_key,
metadata=request_metadata,
session_snapshot=filtered_snapshot,
on_exit=LayerExitSignals(default=ExitIntent.DELETE),
)
def build_for_workflow_node(self, run_input: AgentBackendWorkflowNodeRunInput) -> CreateRunRequest:
"""Build a workflow Agent Node run request without defining another wire schema."""
layers: list[RunLayerSpec] = []
@ -261,20 +135,6 @@ class AgentBackendRunRequestBuilder:
metadata=run_input.metadata,
config=run_input.execution_context,
),
]
)
if run_input.include_history:
layers.append(
RunLayerSpec(
name=DIFY_AGENT_HISTORY_LAYER_ID,
type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID,
metadata={**run_input.metadata, "origin": "agent_session_history"},
)
)
layers.extend(
[
RunLayerSpec(
name=DIFY_AGENT_MODEL_LAYER_ID,
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,

View File

@ -11,7 +11,7 @@ from controllers.console.app.error import (
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.app_config.entities import ModelConfig
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.helper.code_executor.code_node_provider import CodeNodeProvider
@ -22,7 +22,7 @@ from core.llm_generator.llm_generator import LLMGenerator
from extensions.ext_database import db
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models import App
from services.workflow_service import WorkflowService
@ -64,9 +64,9 @@ class RuleGenerateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def post(self, current_tenant_id: str):
def post(self):
args = RuleGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
try:
rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args)
@ -93,9 +93,9 @@ class RuleCodeGenerateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def post(self, current_tenant_id: str):
def post(self):
args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
try:
code_result = LLMGenerator.generate_code(
@ -125,9 +125,9 @@ class RuleStructuredOutputGenerateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def post(self, current_tenant_id: str):
def post(self):
args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
try:
structured_output = LLMGenerator.generate_structured_output(
@ -157,9 +157,9 @@ class InstructionGenerateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def post(self, current_tenant_id: str):
def post(self):
args = InstructionGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] | None = next(
(p for p in providers if p.is_accept_language(args.language)), None

View File

@ -11,16 +11,11 @@ from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
)
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import to_timestamp
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.enums import AppMCPServerStatus
from models.model import App, AppMCPServer
@ -97,8 +92,8 @@ class AppMCPServerController(Resource):
@login_required
@setup_required
@edit_permission_required
@with_current_tenant_id
def post(self, current_tenant_id: str, app_model: App):
def post(self, app_model: App):
_, current_tenant_id = current_account_with_tenant()
payload = MCPServerCreatePayload.model_validate(console_ns.payload or {})
description = payload.description
@ -168,8 +163,8 @@ class AppMCPServerRefreshController(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@with_current_tenant_id
def get(self, current_tenant_id: str, server_id: UUID):
def get(self, server_id: UUID):
_, current_tenant_id = current_account_with_tenant()
server = db.session.scalar(
select(AppMCPServer)
.where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id)

View File

@ -83,14 +83,13 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
# create a copy of the value to avoid affecting the model cache.
value = value.model_copy(deep=True)
# Refresh the url signature before returning it to client.
match value:
case FileSegment():
file = value.value
if isinstance(value, FileSegment):
file = value.value
file.remote_url = file.generate_url()
elif isinstance(value, ArrayFileSegment):
files = value.value
for file in files:
file.remote_url = file.generate_url()
case ArrayFileSegment():
files = value.value
for file in files:
file.remote_url = file.generate_url()
return _convert_values_to_json_serializable_object(value)

View File

@ -169,12 +169,9 @@ class DatasetDocumentSegmentListApi(Resource):
# Use database-specific methods for JSON array search
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
# PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text
# Guard with jsonb_typeof to avoid "cannot extract elements from a scalar" error
# when keywords is null or a non-array JSON value.
keywords_condition = func.array_to_string(
func.array(
select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB)))
.where(func.jsonb_typeof(cast(DocumentSegment.keywords, JSONB)) == "array")
.correlate(DocumentSegment)
.scalar_subquery()
),

View File

@ -8,17 +8,12 @@ from pydantic import BaseModel, Field, field_validator
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
)
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import uuid_value
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService
@ -143,8 +138,9 @@ class DefaultModelApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str):
def get(self):
_, tenant_id = current_account_with_tenant()
args = ParserGetDefault.model_validate(request.args.to_dict(flat=True))
model_provider_service = ModelProviderService()
@ -160,8 +156,9 @@ class DefaultModelApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_tenant_id
def post(self, tenant_id: str):
def post(self):
_, tenant_id = current_account_with_tenant()
args = ParserPostDefault.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
model_settings = args.model_settings
@ -192,8 +189,9 @@ class ModelProviderModelApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str, provider):
def get(self, provider):
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
@ -204,9 +202,9 @@ class ModelProviderModelApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_tenant_id
def post(self, tenant_id: str, provider: str):
def post(self, provider: str):
# To save the model's load balance configs
_, tenant_id = current_account_with_tenant()
args = ParserPostModels.model_validate(console_ns.payload)
if args.config_from == "custom-model":
@ -251,8 +249,9 @@ class ModelProviderModelApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_tenant_id
def delete(self, tenant_id: str, provider: str):
def delete(self, provider: str):
_, tenant_id = current_account_with_tenant()
args = ParserDeleteModels.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@ -269,8 +268,9 @@ class ModelProviderModelCredentialApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str, provider: str):
def get(self, provider: str):
_, tenant_id = current_account_with_tenant()
args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True))
model_provider_service = ModelProviderService()
@ -323,8 +323,9 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_tenant_id
def post(self, tenant_id: str, provider: str):
def post(self, provider: str):
_, tenant_id = current_account_with_tenant()
args = ParserCreateCredential.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@ -354,8 +355,8 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_tenant_id
def put(self, current_tenant_id: str, provider: str):
def put(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
args = ParserUpdateCredential.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@ -381,8 +382,8 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_tenant_id
def delete(self, current_tenant_id: str, provider: str):
def delete(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
args = ParserDeleteCredential.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@ -405,8 +406,8 @@ class ModelProviderModelCredentialSwitchApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_tenant_id
def post(self, current_tenant_id: str, provider: str):
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
args = ParserSwitch.model_validate(console_ns.payload)
service = ModelProviderService()
@ -429,8 +430,9 @@ class ModelProviderModelEnableApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def patch(self, tenant_id: str, provider: str):
def patch(self, provider: str):
_, tenant_id = current_account_with_tenant()
args = ParserDeleteModels.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@ -450,8 +452,9 @@ class ModelProviderModelDisableApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def patch(self, tenant_id: str, provider: str):
def patch(self, provider: str):
_, tenant_id = current_account_with_tenant()
args = ParserDeleteModels.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@ -477,8 +480,8 @@ class ModelProviderModelValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def post(self, tenant_id: str, provider: str):
def post(self, provider: str):
_, tenant_id = current_account_with_tenant()
args = ParserValidate.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@ -512,9 +515,9 @@ class ModelProviderModelParameterRuleApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str, provider: str):
def get(self, provider: str):
args = ParserParameter.model_validate(request.args.to_dict(flat=True))
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
parameter_rules = model_provider_service.get_model_parameter_rules(
@ -529,8 +532,8 @@ class ModelProviderAvailableModelApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str, model_type: str):
def get(self, model_type: str):
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)

View File

@ -4,7 +4,7 @@ from datetime import UTC, datetime
from flask import request
from flask_restx import Resource
from werkzeug.exceptions import NotFound
from werkzeug.exceptions import BadRequest, NotFound
from controllers.openapi import openapi_ns
from controllers.openapi._models import (
@ -17,17 +17,18 @@ from controllers.openapi._models import (
SessionRow,
WorkspacePayload,
)
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.oauth_bearer import (
Scope,
TokenType,
ACCEPT_USER_ANY,
AuthContext,
SubjectType,
get_auth_ctx,
validate_bearer,
)
from libs.rate_limit import (
LIMIT_ME_PER_ACCOUNT,
LIMIT_ME_PER_EMAIL,
enforce,
)
from services.account_service import AccountService, TenantService
@ -41,18 +42,32 @@ from services.oauth_device_flow import (
@openapi_ns.route("/account")
class AccountApi(Resource):
@openapi_ns.response(200, "Account info", openapi_ns.models[AccountResponse.__name__])
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def get(self, *, auth_data: AuthData):
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{auth_data.account_id}")
@validate_bearer(accept=ACCEPT_USER_ANY)
def get(self):
ctx = get_auth_ctx()
account_id_str = str(auth_data.account_id) if auth_data.account_id else None
account = AccountService.get_account_by_id(db.session, account_id_str) if account_id_str else None
memberships = TenantService.get_account_memberships(db.session, account_id_str) if account_id_str else []
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}")
else:
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}")
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
return AccountResponse(
subject_type=ctx.subject_type,
subject_email=ctx.subject_email,
subject_issuer=ctx.subject_issuer,
account=None,
workspaces=[],
default_workspace_id=None,
).model_dump(mode="json")
account = AccountService.get_account_by_id(db.session, str(ctx.account_id)) if ctx.account_id else None
memberships = TenantService.get_account_memberships(db.session, str(ctx.account_id)) if ctx.account_id else []
default_ws_id = _pick_default_workspace(memberships)
return AccountResponse(
subject_type="account",
subject_email=account.email if account else None,
subject_type=ctx.subject_type,
subject_email=ctx.subject_email or (account.email if account else None),
account=_account_payload(account) if account else None,
workspaces=[_workspace_payload(m) for m in memberships],
default_workspace_id=default_ws_id,
@ -62,17 +77,19 @@ class AccountApi(Resource):
@openapi_ns.route("/account/sessions/self")
class AccountSessionsSelfApi(Resource):
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def delete(self, *, auth_data: AuthData):
revoke_oauth_token(db.session, redis_client, str(auth_data.token_id))
@validate_bearer(accept=ACCEPT_USER_ANY)
def delete(self):
ctx = get_auth_ctx()
_require_oauth_subject(ctx)
revoke_oauth_token(db.session, redis_client, str(ctx.token_id))
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
@openapi_ns.route("/account/sessions")
class AccountSessionsApi(Resource):
@openapi_ns.response(200, "Session list", openapi_ns.models[SessionListResponse.__name__])
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def get(self, *, auth_data: AuthData):
@validate_bearer(accept=ACCEPT_USER_ANY)
def get(self):
ctx = get_auth_ctx()
now = datetime.now(UTC)
page = int(request.args.get("page", "1"))
@ -105,9 +122,10 @@ class AccountSessionsApi(Resource):
@openapi_ns.route("/account/sessions/<string:session_id>")
class AccountSessionByIdApi(Resource):
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def delete(self, session_id: str, *, auth_data: AuthData):
@validate_bearer(accept=ACCEPT_USER_ANY)
def delete(self, session_id: str):
ctx = get_auth_ctx()
_require_oauth_subject(ctx)
# 404 (not 403) on cross-subject so the endpoint doesn't leak
# token IDs that belong to other subjects.
@ -118,6 +136,13 @@ class AccountSessionByIdApi(Resource):
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
def _require_oauth_subject(ctx: AuthContext) -> None:
if not ctx.source.startswith("oauth"):
raise BadRequest(
"this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs"
)
def _iso(dt: datetime | None) -> str | None:
if dt is None:
return None

View File

@ -16,8 +16,7 @@ import services
from controllers.openapi import openapi_ns
from controllers.openapi._audit import emit_app_run
from controllers.openapi._models import AppRunRequest
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from controllers.service_api.app.error import (
AppUnavailableError,
CompletionRequestError,
@ -125,9 +124,8 @@ _DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = {
class AppRunApi(Resource):
@openapi_ns.expect(openapi_ns.models[AppRunRequest.__name__])
@openapi_ns.response(200, "Run result (SSE stream)")
@auth_router.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, *, auth_data: AuthData):
app_model, caller, caller_kind = auth_data.require_app_context()
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
body = request.get_json(silent=True) or {}
try:
payload = AppRunRequest.model_validate(body)
@ -160,9 +158,8 @@ class AppRunApi(Resource):
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/stop")
class AppRunTaskStopApi(Resource):
@openapi_ns.response(200, "Task stopped")
@auth_router.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, task_id: str, *, auth_data: AuthData):
app_model, caller, caller_kind = auth_data.require_app_context()
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
AppQueueManager.set_stop_flag_no_user_check(task_id)
GraphEngineManager(redis_client).send_stop_command(task_id)
return {"result": "success"}

View File

@ -1,4 +1,9 @@
"""GET /openapi/v1/apps and per-app reads."""
"""GET /openapi/v1/apps and per-app reads.
Decorator order: `method_decorators` is innermost-first. `validate_bearer`
is last → outermost → publishes the auth ContextVar before `require_scope`
reads it.
"""
from __future__ import annotations
@ -23,17 +28,31 @@ from controllers.openapi._models import (
AppListRow,
TagItem,
)
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from controllers.openapi.auth.surface_gate import accept_subjects
from controllers.service_api.app.error import AppUnavailableError
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from extensions.ext_database import db
from libs.oauth_bearer import Scope, TokenType
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
AuthContext,
Scope,
SubjectType,
get_auth_ctx,
require_scope,
require_workspace_member,
validate_bearer,
)
from models import App
from services.account_service import TenantService
from services.app_service import AppListParams, AppService
from services.tag_service import TagService
_APPS_READ_DECORATORS = [
require_scope(Scope.APPS_READ),
accept_subjects(SubjectType.ACCOUNT),
validate_bearer(accept=ACCEPT_USER_ANY),
]
_ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"})
@ -47,9 +66,13 @@ _EMPTY_PARAMETERS: dict[str, Any] = {
class AppReadResource(Resource):
"""Base for per-app read endpoints; subclasses call `_load()` for membership/exists checks."""
"""Base for per-app read endpoints; subclasses call `_load()` for SSO/membership/exists checks."""
method_decorators = _APPS_READ_DECORATORS
def _load(self, app_id: str, workspace_id: str | None = None) -> tuple[App, AuthContext]:
ctx: AuthContext = get_auth_ctx()
def _load(self, app_id: str, workspace_id: str | None = None) -> App:
try:
parsed_uuid = _uuid.UUID(app_id)
is_uuid = True
@ -76,7 +99,8 @@ class AppReadResource(Resource):
raise Conflict("".join(lines))
app = matches[0]
return app
require_workspace_member(ctx, str(app.tenant_id))
return app, ctx
def parameters_payload(app: App) -> dict:
@ -90,14 +114,13 @@ def parameters_payload(app: App) -> dict:
class AppDescribeApi(AppReadResource):
@openapi_ns.doc(params=query_params_from_model(AppDescribeQuery))
@openapi_ns.response(200, "App description", openapi_ns.models[AppDescribeResponse.__name__])
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def get(self, app_id: str, *, auth_data: AuthData):
def get(self, app_id: str):
try:
query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
app = self._load(app_id, workspace_id=query.workspace_id)
app, _ = self._load(app_id, workspace_id=query.workspace_id)
requested = query.fields
want_info = requested is None or "info" in requested
@ -145,16 +168,20 @@ class AppDescribeApi(AppReadResource):
@openapi_ns.route("/apps")
class AppListApi(Resource):
method_decorators = _APPS_READ_DECORATORS
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def get(self, *, auth_data: AuthData):
def get(self):
ctx: AuthContext = get_auth_ctx()
try:
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
workspace_id = query.workspace_id
require_workspace_member(ctx, workspace_id)
empty = (
AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[]).model_dump(
@ -210,7 +237,7 @@ class AppListApi(Resource):
openapi_visible=True,
)
pagination = AppService().get_paginate_apps(str(auth_data.account_id), workspace_id, params)
pagination = AppService().get_paginate_apps(str(ctx.account_id), workspace_id, params)
if pagination is None:
return empty

View File

@ -18,27 +18,37 @@ from controllers.openapi._models import (
PermittedExternalAppsListQuery,
PermittedExternalAppsListResponse,
)
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData, Edition
from controllers.openapi.auth.surface_gate import accept_subjects
from extensions.ext_database import db
from libs.oauth_bearer import Scope, TokenType
from libs.device_flow_security import enterprise_only
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
Scope,
SubjectType,
require_scope,
validate_bearer,
)
from models import App
from services.account_service import TenantService
from services.app_service import AppService
from services.enterprise.app_permitted_service import list_permitted_apps
from services.openapi.license_gate import license_required
@openapi_ns.route("/permitted-external-apps")
class PermittedExternalAppsListApi(Resource):
method_decorators = [
require_scope(Scope.APPS_READ_PERMITTED_EXTERNAL),
license_required,
accept_subjects(SubjectType.EXTERNAL_SSO),
validate_bearer(accept=ACCEPT_USER_ANY),
enterprise_only,
]
@openapi_ns.response(
200, "Permitted external apps list", openapi_ns.models[PermittedExternalAppsListResponse.__name__]
)
@auth_router.guard(
scope=Scope.APPS_READ_PERMITTED_EXTERNAL,
allowed_token_types=frozenset({TokenType.OAUTH_EXTERNAL_SSO}),
edition=frozenset({Edition.EE}),
)
def get(self, *, auth_data: AuthData):
def get(self):
try:
query = PermittedExternalAppsListQuery.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:

View File

@ -1,3 +1,3 @@
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
__all__ = ["auth_router"]
__all__ = ["OAUTH_BEARER_PIPELINE"]

View File

@ -1,64 +1,46 @@
"""`OAUTH_BEARER_PIPELINE` — the auth scheme for openapi `/run` endpoints.
Endpoints attach via `@OAUTH_BEARER_PIPELINE.guard(scope=…)`. No alternative
paths. Read endpoints (`/apps`, `/info`, `/parameters`, `/describe`) skip
the pipeline and use `validate_bearer + require_scope + require_workspace_member`
inline — they don't need `AppAuthzCheck`/`CallerMount`.
"""
from __future__ import annotations
from controllers.openapi.auth.conditions import (
EDITION_CE,
EDITION_EE,
LOADED_APP_IS_PRIVATE,
PATH_HAS_APP_ID,
WEBAPP_AUTH_ENABLED,
from controllers.openapi.auth.pipeline import Pipeline
from controllers.openapi.auth.steps import (
AppAuthzCheck,
AppResolver,
BearerCheck,
CallerMount,
ScopeCheck,
SurfaceCheck,
WorkspaceMembershipCheck,
)
from controllers.openapi.auth.data import Edition
from controllers.openapi.auth.flow import When
from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter
from controllers.openapi.auth.prepare import (
load_account,
load_app,
load_app_access_mode,
load_tenant,
resolve_external_user,
from controllers.openapi.auth.strategies import (
AccountMounter,
AclStrategy,
AppAuthzStrategy,
EndUserMounter,
MembershipStrategy,
)
from controllers.openapi.auth.verify import (
check_acl,
check_app_access,
check_membership,
check_private_app_permission,
check_scope,
)
from libs.oauth_bearer import TokenType
from libs.oauth_bearer import SubjectType
from services.feature_service import FeatureService
account_pipeline = AuthPipeline(
prepare=[
When(PATH_HAS_APP_ID, then=load_app),
When(PATH_HAS_APP_ID, then=load_tenant),
load_account, # all tokens here are account tokens
When(PATH_HAS_APP_ID & EDITION_EE, then=load_app_access_mode),
],
auth=[
check_scope,
When(EDITION_CE & PATH_HAS_APP_ID, then=check_membership),
When(EDITION_EE & PATH_HAS_APP_ID & ~WEBAPP_AUTH_ENABLED, then=check_app_access),
When(PATH_HAS_APP_ID & EDITION_EE & WEBAPP_AUTH_ENABLED, then=check_acl),
When(EDITION_EE & LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
],
)
external_sso_pipeline = AuthPipeline(
prepare=[
When(PATH_HAS_APP_ID, then=load_app),
When(PATH_HAS_APP_ID, then=load_tenant),
When(PATH_HAS_APP_ID, then=resolve_external_user),
When(PATH_HAS_APP_ID, then=load_app_access_mode),
],
auth=[
check_scope,
When(PATH_HAS_APP_ID & WEBAPP_AUTH_ENABLED, then=check_acl),
When(LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
],
)
def _resolve_app_authz_strategy() -> AppAuthzStrategy:
if FeatureService.get_system_features().webapp_auth.enabled:
return AclStrategy()
return MembershipStrategy()
auth_router = PipelineRouter(
{
TokenType.OAUTH_ACCOUNT: PipelineRoute(account_pipeline),
TokenType.OAUTH_EXTERNAL_SSO: PipelineRoute(external_sso_pipeline, required_edition=frozenset({Edition.EE})),
}
OAUTH_BEARER_PIPELINE = Pipeline(
BearerCheck(),
SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT})),
ScopeCheck(),
AppResolver(),
WorkspaceMembershipCheck(),
AppAuthzCheck(_resolve_app_authz_strategy),
CallerMount(AccountMounter(), EndUserMounter()),
)

View File

@ -1,53 +0,0 @@
from __future__ import annotations
from collections.abc import Callable
from controllers.openapi.auth.data import AuthData, Edition, RequestContext, current_edition
from libs.oauth_bearer import TokenType
from services.enterprise.enterprise_service import WebAppAccessMode
from services.feature_service import FeatureService
CondFn = Callable[[RequestContext, AuthData | None], bool]
class Cond:
def __init__(self, fn: CondFn) -> None:
self._fn = fn
def __call__(self, ctx: RequestContext, data: AuthData | None = None) -> bool:
return self._fn(ctx, data)
def __and__(self, other: Cond) -> Cond:
return Cond(lambda ctx, data: self(ctx, data) and other(ctx, data))
def __or__(self, other: Cond) -> Cond:
return Cond(lambda ctx, data: self(ctx, data) or other(ctx, data))
def __invert__(self) -> Cond:
return Cond(lambda ctx, data: not self(ctx, data))
def request_cond(fn: Callable[[RequestContext], bool]) -> Cond:
return Cond(lambda ctx, _: fn(ctx))
def data_cond(fn: Callable[[AuthData], bool]) -> Cond:
return Cond(lambda _, data: data is not None and fn(data))
def config_cond(fn: Callable[[], bool]) -> Cond:
return Cond(lambda _, __: fn())
TOKEN_IS_OAUTH_ACCOUNT = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_ACCOUNT)
TOKEN_IS_OAUTH_EXTERNAL_SSO = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_EXTERNAL_SSO)
PATH_HAS_APP_ID = request_cond(lambda ctx: "app_id" in ctx.path_params)
EDITION_CE = config_cond(lambda: current_edition() == Edition.CE)
EDITION_EE = config_cond(lambda: current_edition() == Edition.EE)
EDITION_SAAS = config_cond(lambda: current_edition() == Edition.SAAS)
WEBAPP_AUTH_ENABLED = config_cond(lambda: FeatureService.get_system_features().webapp_auth.enabled)
LOADED_APP_IS_PRIVATE = data_cond(lambda data: data.app_access_mode == WebAppAccessMode.PRIVATE)

View File

@ -0,0 +1,68 @@
"""Mutable per-request context for the openapi auth pipeline.
Every field starts None / empty and is filled in by a step. The pipeline
is the only thing that should construct or mutate Context — handlers
read populated values via the decorator's kwargs unpacking.
Context is intentionally decoupled from Flask's ``Request``: the pipeline
guard extracts whatever transport-level inputs the steps need (bearer
token, path params) at the boundary and writes them into Context fields,
so steps stay testable without a request object and won't leak coupling
to a specific framework.
"""
from __future__ import annotations
import uuid
from collections.abc import Mapping
from contextvars import Token
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING, Literal, Protocol
from werkzeug.exceptions import Unauthorized
from libs.oauth_bearer import AuthContext, Scope, SubjectType
if TYPE_CHECKING:
from models import App, Tenant
@dataclass
class Context:
required_scope: Scope
bearer_token: str | None = None
path_params: Mapping[str, str] = field(default_factory=dict)
subject_type: SubjectType | None = None
subject_email: str | None = None
subject_issuer: str | None = None
account_id: uuid.UUID | None = None
scopes: frozenset[Scope] = field(default_factory=frozenset)
token_id: uuid.UUID | None = None
token_hash: str | None = None
cached_verified_tenants: dict[str, bool] | None = None
source: str | None = None
expires_at: datetime | None = None
app: App | None = None
tenant: Tenant | None = None
caller: object | None = None
caller_kind: Literal["account", "end_user"] | None = None
auth_ctx_reset_token: Token[AuthContext] | None = None
@property
def must_tenant(self) -> Tenant:
if not self.tenant:
raise Unauthorized("tenant is not associated")
return self.tenant
@property
def must_subject_type(self) -> SubjectType:
if not self.subject_type:
raise Unauthorized("subject_type unset — BearerCheck did not run")
return self.subject_type
class Step(Protocol):
"""One responsibility. Mutate ctx or raise to short-circuit."""
def __call__(self, ctx: Context) -> None: ...

View File

@ -1,69 +0,0 @@
from __future__ import annotations
import uuid
from enum import StrEnum
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field
from werkzeug.exceptions import InternalServerError
from configs import dify_config
from libs.oauth_bearer import Scope, TokenType
from models.account import Account, Tenant
from models.model import App, EndUser
from services.enterprise.enterprise_service import WebAppAccessMode
class Edition(StrEnum):
CE = "ce"
EE = "ee"
SAAS = "saas"
def current_edition() -> Edition:
if dify_config.EDITION == "CLOUD":
return Edition.SAAS
if dify_config.ENTERPRISE_ENABLED:
return Edition.EE
return Edition.CE
class ExternalIdentity(BaseModel):
model_config = ConfigDict(frozen=True)
email: str
issuer: str | None = None
class RequestContext(BaseModel):
model_config = ConfigDict(frozen=True)
token_type: TokenType
scope: Scope | None = None
path_params: dict[str, str]
class AuthData(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
required_scope: Scope | None = None
token_type: TokenType
account_id: uuid.UUID | None = None
token_hash: str
token_id: uuid.UUID | None = None
scopes: frozenset[Scope]
tenants: dict[str, bool] = Field(default_factory=dict)
external_identity: ExternalIdentity | None = None
path_params: dict[str, str] = Field(default_factory=dict)
app: App | None = None
tenant: Tenant | None = None
app_access_mode: WebAppAccessMode | None = None
caller: Account | EndUser | None = None
caller_kind: Literal["account", "end_user"] | None = None
def require_app_context(self) -> tuple[App, Account | EndUser, Literal["account", "end_user"]]:
if self.app is None or self.caller is None or self.caller_kind is None:
raise InternalServerError("pipeline_invariant_violated: app context missing")
return self.app, self.caller, self.caller_kind

View File

@ -1,19 +0,0 @@
from __future__ import annotations
from collections.abc import Callable
from typing import Any
from controllers.openapi.auth.conditions import Cond
from controllers.openapi.auth.data import AuthData, RequestContext
class When:
def __init__(self, condition: Cond, *, then: Callable[[Any], None]) -> None:
self.condition = condition
self._step = then
def applies(self, ctx: RequestContext, data: AuthData | None = None) -> bool:
return self.condition(ctx, data)
def __call__(self, arg: Any) -> None:
self._step(arg)

View File

@ -1,209 +1,51 @@
"""Auth pipeline — entry point for all openapi auth.
"""Pipeline IS the auth scheme.
`PipelineRouter.guard()` is the only attachment point for endpoints.
`AuthPipeline` is a pure step-runner with no routing concerns.
`PipelineRoute` binds a pipeline to optional edition requirements.
`Pipeline.guard(scope=…)` is the only attachment point for endpoints
that is the design lock-in: forgetting an auth layer is structurally
impossible because there is no "sometimes wrap, sometimes don't" choice.
"""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from functools import wraps
from typing import Any
from flask import current_app, request
from flask_login import user_logged_in
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from flask import request
from controllers.openapi._audit import emit_wrong_surface
from controllers.openapi.auth.data import (
AuthData,
Edition,
ExternalIdentity,
RequestContext,
current_edition,
)
from controllers.openapi.auth.flow import When
from libs.oauth_bearer import (
AuthContext,
Scope,
TokenType,
extract_bearer,
get_authenticator,
reset_auth_ctx,
set_auth_ctx,
)
from services.feature_service import FeatureService, LicenseStatus
from controllers.openapi.auth.context import Context, Step
from libs.oauth_bearer import Scope, extract_bearer, reset_auth_ctx
class AuthPipeline:
"""Pure step-runner — no routing, no guard.
class Pipeline:
def __init__(self, *steps: Step) -> None:
self._steps = steps
Both `prepare` and `auth` steps receive the same `AuthData` instance.
`prepare` steps populate it; `auth` steps validate it.
"""
def run(self, ctx: Context) -> None:
for step in self._steps:
step(ctx)
def __init__(self, prepare: list, auth: list) -> None:
self._prepare = prepare
self._auth = auth
def _run(
self,
identity: AuthContext,
args: tuple,
kwargs: dict,
view: Callable,
*,
scope: Scope | None,
) -> Any:
req_ctx = RequestContext(
token_type=identity.token_type,
scope=scope,
path_params=dict(request.view_args or {}),
)
data = AuthData(
token_type=identity.token_type,
account_id=identity.account_id,
token_hash=identity.token_hash,
token_id=identity.token_id,
scopes=frozenset(identity.scopes),
tenants=dict(identity.verified_tenants),
required_scope=scope,
path_params=dict(req_ctx.path_params),
external_identity=(
ExternalIdentity(email=identity.subject_email, issuer=identity.subject_issuer)
if identity.subject_email
else None
),
)
for step in self._prepare:
if _should_run(step, req_ctx, data=None):
step(data)
for step in self._auth:
if _should_run(step, req_ctx, data=data):
step(data)
reset_token = set_auth_ctx(identity)
if data.caller:
_mount_flask_login(data.caller)
try:
kwargs["auth_data"] = data
return view(*args, **kwargs)
finally:
reset_auth_ctx(reset_token)
@dataclass(frozen=True)
class PipelineRoute:
pipeline: AuthPipeline
required_edition: frozenset[Edition] | None = None
class PipelineRouter:
"""Entry point for openapi auth.
`guard()` is the decorator that endpoints attach to. It applies
global gates (edition, token type) then dispatches to the matching
`PipelineRoute` for the token type.
"""
def __init__(self, routes: dict[TokenType, PipelineRoute]) -> None:
self._routes = routes
def guard(
self,
*,
scope: Scope | None = None,
allowed_token_types: frozenset[TokenType] | None = None,
edition: frozenset[Edition] | None = None,
) -> Callable:
def decorator(view: Callable) -> Callable:
def guard(self, *, scope: Scope):
def decorator(view):
@wraps(view)
def decorated(*args: Any, **kwargs: Any) -> Any:
return self._execute(
args,
kwargs,
view,
scope=scope,
allowed_token_types=allowed_token_types,
edition=edition,
def decorated(*args, **kwargs):
# Extract transport-level inputs at the boundary so steps
# stay decoupled from Flask's request object.
ctx = Context(
required_scope=scope,
bearer_token=extract_bearer(request),
path_params=dict(request.view_args or {}),
)
try:
self.run(ctx)
kwargs.update(
app_model=ctx.app,
caller=ctx.caller,
caller_kind=ctx.caller_kind,
)
return view(*args, **kwargs)
finally:
if ctx.auth_ctx_reset_token is not None:
reset_auth_ctx(ctx.auth_ctx_reset_token)
return decorated
return decorator
def _execute(
self,
args: tuple,
kwargs: dict,
view: Callable,
*,
scope: Scope | None,
allowed_token_types: frozenset[TokenType] | None,
edition: frozenset[Edition] | None,
) -> Any:
# 404 not 403 — this edition doesn't expose the feature at all
if edition is not None and current_edition() not in edition:
raise NotFound()
license_checked = False
if edition is not None and Edition.EE in edition:
_check_license()
license_checked = True
token = extract_bearer(request)
if not token:
raise Unauthorized("bearer required")
identity = get_authenticator().authenticate(token)
if allowed_token_types is not None and identity.token_type not in allowed_token_types:
emit_wrong_surface(
subject_type=_subject_type_str(identity),
attempted_path=request.path,
client_id=getattr(identity, "client_id", None),
token_id=str(identity.token_id) if identity.token_id else None,
)
raise Forbidden("unsupported_token_type")
route = self._routes.get(identity.token_type)
if route is None:
raise Forbidden("unsupported_token_type")
if route.required_edition is not None:
if current_edition() not in route.required_edition:
raise Forbidden("external_sso_requires_ee")
if not license_checked and Edition.EE in route.required_edition:
_check_license()
return route.pipeline._run(identity, args, kwargs, view, scope=scope)
def _should_run(step: Any, req_ctx: RequestContext, data: AuthData | None) -> bool:
if isinstance(step, When):
return step.applies(req_ctx, data)
return True
def _subject_type_str(identity: Any) -> str | None:
subject = getattr(identity, "subject_type", None)
if subject is None:
return None
return subject.value if hasattr(subject, "value") else str(subject)
def _check_license() -> None:
settings = FeatureService.get_system_features()
if settings.license.status in {LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST}:
raise Forbidden("license_invalid")
def _mount_flask_login(user: Any) -> None:
current_app.login_manager._update_request_context_with_user(user) # type: ignore[attr-defined]
user_logged_in.send(current_app._get_current_object(), user=user) # type: ignore[attr-defined]

View File

@ -1,67 +0,0 @@
from __future__ import annotations
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound, Unauthorized
from controllers.openapi.auth.data import AuthData
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from models.account import TenantStatus
from services.account_service import AccountService, TenantService
from services.app_service import AppService
from services.end_user_service import EndUserService
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode
def load_app(data: AuthData) -> None:
app_id = data.path_params["app_id"]
app = AppService.get_app_by_id(db.session, app_id)
if not app or app.status != "normal":
raise NotFound("app not found")
if not app.enable_api:
raise Forbidden("service_api_disabled")
data.app = app
def load_tenant(data: AuthData) -> None:
if data.app is None:
raise InternalServerError("pipeline_invariant_violated: app not loaded before load_tenant")
tenant = TenantService.get_tenant_by_id(db.session, str(data.app.tenant_id))
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("workspace unavailable")
data.tenant = tenant
def load_account(data: AuthData) -> None:
account = AccountService.get_account_by_id(db.session, str(data.account_id))
if account is None:
raise Unauthorized("account not found")
if data.tenant:
account.current_tenant = data.tenant
data.caller = account
data.caller_kind = "account"
def resolve_external_user(data: AuthData) -> None:
if data.tenant is None or data.app is None or data.external_identity is None:
raise Unauthorized("missing context for external user resolution")
end_user = EndUserService.get_or_create_end_user_by_type(
InvokeFrom.OPENAPI,
tenant_id=str(data.tenant.id),
app_id=str(data.app.id),
user_id=data.external_identity.email,
)
data.caller = end_user
data.caller_kind = "end_user"
def load_app_access_mode(data: AuthData) -> None:
if data.app is None:
return
try:
settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(data.app.id))
if settings is None:
data.app_access_mode = None
return
data.app_access_mode = WebAppAccessMode(settings.access_mode)
except ValueError:
data.app_access_mode = None

View File

@ -0,0 +1,170 @@
"""Pipeline steps. Each is one responsibility.
`BearerCheck` is the only step that touches the token registry; downstream
steps see only the populated `Context`. `BearerCheck` also publishes the
resolved identity to the openapi auth ``ContextVar`` (the same one the
decorator-level :func:`libs.oauth_bearer.validate_bearer` writes to) so the
surface gate and any handler reading the request-scoped context has a single
source of truth across both auth-attach paths. The reset token is stashed
on `ctx.auth_ctx_reset_token`; `Pipeline.guard` resets the ContextVar in
its `finally` so worker-thread reuse can't leak identity across requests.
"""
from __future__ import annotations
from collections.abc import Callable
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
from configs import dify_config
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter
from controllers.openapi.auth.surface_gate import check_surface
from extensions.ext_database import db
from libs.oauth_bearer import (
AuthContext,
InvalidBearerError,
Scope,
SubjectType,
check_workspace_membership,
get_authenticator,
set_auth_ctx,
)
from models import TenantStatus
from services.account_service import TenantService
from services.app_service import AppService
class BearerCheck:
"""Resolve bearer → populate identity fields. Rate-limit is enforced
inside `BearerAuthenticator.authenticate`, so no separate step here.
Also publishes the resolved `AuthContext` via
:func:`libs.oauth_bearer.set_auth_ctx` — same shape the decorator-level
``validate_bearer`` writes — so the surface gate + downstream readers
don't see two different identity sources. The reset token is parked on
``ctx.auth_ctx_reset_token`` for `Pipeline.guard` to consume."""
def __call__(self, ctx: Context) -> None:
if not ctx.bearer_token:
raise Unauthorized("bearer required")
try:
authn = get_authenticator().authenticate(ctx.bearer_token)
except InvalidBearerError as e:
raise Unauthorized(str(e))
ctx.subject_type = authn.subject_type
ctx.subject_email = authn.subject_email
ctx.subject_issuer = authn.subject_issuer
ctx.account_id = authn.account_id
ctx.scopes = frozenset(authn.scopes)
ctx.source = authn.source
ctx.token_id = authn.token_id
ctx.expires_at = authn.expires_at
ctx.token_hash = authn.token_hash
ctx.cached_verified_tenants = dict(authn.verified_tenants)
ctx.auth_ctx_reset_token = set_auth_ctx(authn)
class ScopeCheck:
"""Verify ctx.scopes (already populated by BearerCheck) covers required."""
def __call__(self, ctx: Context) -> None:
if Scope.FULL in ctx.scopes or ctx.required_scope in ctx.scopes:
return
raise Forbidden("insufficient_scope")
class SurfaceCheck:
"""Reject the request if the resolved subject is not in `accepted`."""
def __init__(self, *, accepted: frozenset[SubjectType]) -> None:
self._accepted = accepted
def __call__(self, ctx: Context) -> None:
check_surface(self._accepted)
class AppResolver:
"""Read ``app_id`` from ``ctx.path_params``; populate ctx.app + ctx.tenant.
Every endpoint using the OAuth bearer pipeline must declare
``<string:app_id>`` in its route — that is the design lock-in (no body /
header coupling). ``Pipeline.guard`` lifts ``request.view_args`` into
``ctx.path_params`` at the boundary so this step doesn't need to know
about the request object.
"""
def __call__(self, ctx: Context) -> None:
app_id = ctx.path_params.get("app_id")
if not app_id:
raise BadRequest("app_id is required in path")
app = AppService.get_app_by_id(db.session, app_id)
if not app or app.status != "normal":
raise NotFound("app not found")
if not app.enable_api:
raise Forbidden("service_api_disabled")
tenant = TenantService.get_tenant_by_id(db.session, str(app.tenant_id))
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("workspace unavailable")
ctx.app, ctx.tenant = app, tenant
class WorkspaceMembershipCheck:
"""Layer 0 — workspace membership gate.
CE-only (skipped when ENTERPRISE_ENABLED). Account-subject bearers
(dfoa_) only — SSO subjects skip.
"""
def __call__(self, ctx: Context) -> None:
if dify_config.ENTERPRISE_ENABLED:
return
if ctx.subject_type != SubjectType.ACCOUNT:
return
if ctx.account_id is None or ctx.tenant is None:
raise Unauthorized("account_id or tenant unset — BearerCheck or AppResolver did not run")
if ctx.token_hash is None:
raise Unauthorized("token_hash unset — BearerCheck did not run")
check_workspace_membership(
account_id=ctx.account_id,
tenant_id=ctx.must_tenant.id,
token_hash=ctx.token_hash,
cached_verdicts=ctx.cached_verified_tenants or {},
)
class AppAuthzCheck:
def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None:
self._resolve = resolve_strategy
def __call__(self, ctx: Context) -> None:
if not self._resolve().authorize(ctx):
raise Forbidden("subject_no_app_access")
class CallerMount:
def __init__(self, *mounters: CallerMounter) -> None:
self._mounters = mounters
def __call__(self, ctx: Context) -> None:
if ctx.subject_type is None:
raise Unauthorized("subject_type unset — BearerCheck did not run")
for m in self._mounters:
if m.applies_to(ctx.must_subject_type):
m.mount(ctx)
return
raise Unauthorized("no caller mounter for subject type")
__all__ = [
"AppAuthzCheck",
"AppResolver",
"AuthContext",
"BearerCheck",
"CallerMount",
"ScopeCheck",
"SurfaceCheck",
"WorkspaceMembershipCheck",
]

View File

@ -0,0 +1,168 @@
"""Strategy classes for the openapi auth pipeline.
App authorization (Acl/Membership) and caller mounting (Account/EndUser)
vary along independent axes; each strategy is one class so the pipeline
composition stays a flat list.
"""
from __future__ import annotations
from typing import Protocol
from flask import current_app
from flask_login import user_logged_in
from controllers.openapi.auth.context import Context
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.oauth_bearer import SubjectType
from services.account_service import AccountService, TenantService
from services.end_user_service import EndUserService
from services.enterprise.enterprise_service import (
EnterpriseService,
WebAppAccessMode,
)
class AppAuthzStrategy(Protocol):
def authorize(self, ctx: Context) -> bool: ...
class AclStrategy:
"""Per-app ACL, evaluated in two stages.
The EE gateway has already enforced tenancy and workspace membership
by the time this strategy runs, so AclStrategy only owns per-app ACL:
1. Subject vs access-mode compatibility (pure rule table). External-SSO
bearers belong to public-facing apps only; account bearers cover the
full set. A mismatch is an immediate deny — no IO.
2. For modes that pair with the subject, decide whether the inner
permission API must run. Only `PRIVATE` (per-app selected-user list)
requires it; the remaining modes are pass-through.
"""
_ALLOWED_MODES_BY_SUBJECT: dict[SubjectType, frozenset[WebAppAccessMode]] = {
SubjectType.ACCOUNT: frozenset(
{
WebAppAccessMode.PUBLIC,
WebAppAccessMode.SSO_VERIFIED,
WebAppAccessMode.PRIVATE_ALL,
WebAppAccessMode.PRIVATE,
}
),
SubjectType.EXTERNAL_SSO: frozenset(
{
WebAppAccessMode.PUBLIC,
WebAppAccessMode.SSO_VERIFIED,
}
),
}
_MODES_REQUIRING_INNER_CHECK: frozenset[WebAppAccessMode] = frozenset({WebAppAccessMode.PRIVATE})
def authorize(self, ctx: Context) -> bool:
if ctx.app is None:
return False
access_mode = self._fetch_access_mode(ctx.app.id)
if access_mode is None:
return False
if not self._subject_allowed_for_mode(ctx.must_subject_type, access_mode):
return False
if access_mode not in self._MODES_REQUIRING_INNER_CHECK:
return True
return self._inner_permission_check(ctx)
@staticmethod
def _fetch_access_mode(app_id: str) -> WebAppAccessMode | None:
settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
if settings is None:
return None
try:
return WebAppAccessMode(settings.access_mode)
except ValueError:
return None
@classmethod
def _subject_allowed_for_mode(cls, subject_type: SubjectType, access_mode: WebAppAccessMode) -> bool:
return access_mode in cls._ALLOWED_MODES_BY_SUBJECT.get(subject_type, frozenset())
def _inner_permission_check(self, ctx: Context) -> bool:
if ctx.app is None:
return False
user_id = self._resolve_user_id(ctx)
if user_id is None:
return False
return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
user_id=user_id,
app_id=ctx.app.id,
)
@staticmethod
def _resolve_user_id(ctx: Context) -> str | None:
if ctx.subject_type == SubjectType.ACCOUNT:
return str(ctx.account_id) if ctx.account_id is not None else None
if ctx.subject_email is None:
return None
account = AccountService.get_account_by_email(db.session, ctx.subject_email)
return str(account.id) if account is not None else None
class MembershipStrategy:
"""Tenant-membership fallback.
Used when webapp-auth is disabled (CE deployment). Account-bearing
subjects pass if they have a TenantAccountJoin row; EXTERNAL_SSO is
denied (it requires the webapp-auth surface).
"""
def authorize(self, ctx: Context) -> bool:
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
return False
if ctx.tenant is None:
return False
return TenantService.account_belongs_to_tenant(db.session, ctx.account_id, ctx.tenant.id)
def _login_as(user) -> None:
"""Set Flask-Login request user so downstream services see the caller."""
current_app.login_manager._update_request_context_with_user(user) # type:ignore
user_logged_in.send(current_app._get_current_object(), user=user) # type:ignore
class CallerMounter(Protocol):
def applies_to(self, subject_type: SubjectType) -> bool: ...
def mount(self, ctx: Context) -> None: ...
class AccountMounter:
def applies_to(self, subject_type: SubjectType) -> bool:
return subject_type == SubjectType.ACCOUNT
def mount(self, ctx: Context) -> None:
if ctx.account_id is None:
raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run")
account = AccountService.get_account_by_id(db.session, str(ctx.account_id))
if account is None:
raise RuntimeError("AccountMounter: account row missing for resolved bearer")
account.current_tenant = ctx.must_tenant
_login_as(account)
ctx.caller, ctx.caller_kind = account, "account"
class EndUserMounter:
def applies_to(self, subject_type: SubjectType) -> bool:
return subject_type == SubjectType.EXTERNAL_SSO
def mount(self, ctx: Context) -> None:
if ctx.tenant is None or ctx.app is None or ctx.subject_email is None:
raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run")
end_user = EndUserService.get_or_create_end_user_by_type(
InvokeFrom.OPENAPI,
tenant_id=ctx.tenant.id,
app_id=ctx.app.id,
user_id=ctx.subject_email,
)
_login_as(end_user)
ctx.caller, ctx.caller_kind = end_user, "end_user"

View File

@ -1,82 +0,0 @@
from __future__ import annotations
from werkzeug.exceptions import Forbidden, Unauthorized
from controllers.openapi.auth.data import AuthData
from extensions.ext_database import db
from libs.oauth_bearer import Scope, TokenType, check_workspace_membership
from services.account_service import AccountService, TenantService
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode
def check_scope(data: AuthData) -> None:
if data.required_scope is None:
return
if Scope.FULL in data.scopes or data.required_scope in data.scopes:
return
raise Forbidden("insufficient_scope")
def check_membership(data: AuthData) -> None:
if data.tenant is None:
raise Unauthorized("tenant unset")
if data.account_id is None:
raise Unauthorized("account_id unset")
check_workspace_membership(
account_id=data.account_id,
tenant_id=data.tenant.id,
token_hash=data.token_hash,
membership_cache=data.tenants,
)
def check_app_access(data: AuthData) -> None:
if data.tenant is None:
return
if not TenantService.account_belongs_to_tenant(db.session, data.account_id, data.tenant.id):
raise Forbidden("subject_no_app_access")
_ALLOWED_MODES_BY_TOKEN_TYPE: dict[TokenType, frozenset[WebAppAccessMode]] = {
TokenType.OAUTH_ACCOUNT: frozenset(
{
WebAppAccessMode.PUBLIC,
WebAppAccessMode.SSO_VERIFIED,
WebAppAccessMode.PRIVATE_ALL,
WebAppAccessMode.PRIVATE,
}
),
TokenType.OAUTH_EXTERNAL_SSO: frozenset(
{
WebAppAccessMode.PUBLIC,
WebAppAccessMode.SSO_VERIFIED,
}
),
}
def check_acl(data: AuthData) -> None:
if data.app is None or data.app_access_mode is None:
raise Forbidden("app or access mode not loaded")
allowed_modes = _ALLOWED_MODES_BY_TOKEN_TYPE.get(data.token_type, frozenset())
if data.app_access_mode not in allowed_modes:
raise Forbidden("subject_not_allowed_for_access_mode")
def check_private_app_permission(data: AuthData) -> None:
if data.app is None:
raise Forbidden("app not loaded")
user_id = _resolve_user_id(data)
if user_id is None:
raise Forbidden("cannot resolve user for private app check")
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id=user_id, app_id=data.app.id):
raise Forbidden("user_not_allowed_for_private_app")
def _resolve_user_id(data: AuthData) -> str | None:
if data.token_type == TokenType.OAUTH_ACCOUNT:
return str(data.account_id) if data.account_id is not None else None
if data.external_identity is None:
return None
account = AccountService.get_account_by_email(db.session, data.external_identity.email)
return str(account.id) if account is not None else None

View File

@ -17,11 +17,11 @@ from controllers.common.errors import (
UnsupportedFileTypeError,
)
from controllers.openapi import openapi_ns
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from extensions.ext_database import db
from fields.file_fields import FileResponse
from libs.oauth_bearer import Scope
from models import Account, App
from services.file_service import FileService
@ -39,9 +39,8 @@ class AppFileUploadApi(Resource):
}
)
@openapi_ns.response(HTTPStatus.CREATED, "File uploaded", openapi_ns.models[FileResponse.__name__])
@auth_router.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, *, auth_data: AuthData):
app_model, caller, _ = auth_data.require_app_context()
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, app_model: App, caller: Account, caller_kind: str):
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:

View File

@ -17,8 +17,7 @@ from werkzeug.exceptions import BadRequest, NotFound
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
from controllers.common.schema import register_schema_models
from controllers.openapi import openapi_ns
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from libs.helper import to_timestamp
@ -56,9 +55,8 @@ def _ensure_form_is_allowed_for_openapi(form) -> None:
@openapi_ns.route("/apps/<string:app_id>/form/human_input/<string:form_token>")
class OpenApiWorkflowHumanInputFormApi(Resource):
@openapi_ns.response(200, "Form definition")
@auth_router.guard(scope=Scope.APPS_RUN)
def get(self, app_id: str, form_token: str, *, auth_data: AuthData):
app_model, caller, caller_kind = auth_data.require_app_context()
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def get(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
@ -71,9 +69,8 @@ class OpenApiWorkflowHumanInputFormApi(Resource):
@openapi_ns.expect(openapi_ns.models[HumanInputFormSubmitPayload.__name__])
@openapi_ns.response(200, "Form submitted")
@auth_router.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, form_token: str, *, auth_data: AuthData):
app_model, caller, caller_kind = auth_data.require_app_context()
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
payload = HumanInputFormSubmitPayload.model_validate(request.get_json(silent=True) or {})
service = HumanInputService(db.engine)

View File

@ -17,8 +17,7 @@ from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound, UnprocessableEntity
from controllers.openapi import openapi_ns
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
@ -29,7 +28,7 @@ from core.workflow.human_input_policy import HumanInputSurface
from extensions.ext_database import db
from libs.oauth_bearer import Scope
from models.enums import CreatorUserRole
from models.model import AppMode
from models.model import App, AppMode
from repositories.factory import DifyAPIRepositoryFactory
from services.workflow_event_snapshot_service import build_workflow_event_stream
@ -37,9 +36,8 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/events")
class OpenApiWorkflowEventsApi(Resource):
@openapi_ns.response(200, "SSE event stream")
@auth_router.guard(scope=Scope.APPS_RUN)
def get(self, app_id: str, task_id: str, *, auth_data: AuthData):
app_model, caller, caller_kind = auth_data.require_app_context()
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def get(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
raise UnprocessableEntity("mode_not_supported_for_event_reconnect")

View File

@ -35,11 +35,15 @@ from controllers.openapi._models import (
WorkspaceListResponse,
WorkspaceSummaryResponse,
)
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from controllers.openapi.auth.role_gate import require_workspace_role
from controllers.openapi.auth.surface_gate import accept_subjects
from extensions.ext_database import db
from libs.oauth_bearer import Scope, TokenType
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
SubjectType,
get_auth_ctx,
validate_bearer,
)
from models import Account, Tenant, TenantAccountJoin
from models.account import TenantAccountRole, TenantStatus
from services.account_service import AccountService, RegisterService, TenantService
@ -56,6 +60,11 @@ from services.feature_service import FeatureService
def _validate_body[M: BaseModel](model: type[M]) -> M:
"""Validate JSON body against ``model``. Validation errors → HTTP 400.
The workspace spec is explicit that bad email / unknown role payloads
are 400, not Pydantic's default 422 — handle uniformly here.
"""
body = request.get_json(silent=True) or {}
try:
return model.model_validate(body)
@ -82,6 +91,7 @@ def _load_tenant(workspace_id: str) -> Tenant:
def _load_account(account_id: object) -> Account:
"""Load the caller's Account. Missing == auth wiring bug, not user error."""
account = AccountService.get_account_by_id(db.session, str(account_id)) if account_id else None
if account is None:
raise RuntimeError("authenticated account_id has no Account row")
@ -89,6 +99,13 @@ def _load_account(account_id: object) -> Account:
def _quota_error(*, code: str, message: str, hint: str) -> Forbidden:
"""Build a 403 with envelope ``{code, message, hint}``.
CLI ``error-mapper`` reads ``message`` and ``hint`` off the wire body
verbatim — the structured envelope lets it surface remediation guidance
(e.g. "upgrade your plan") without the CLI needing to know edition
semantics.
"""
err = Forbidden(message)
err.response = make_response(
jsonify({"code": code, "message": message, "hint": hint}),
@ -98,6 +115,16 @@ def _quota_error(*, code: str, message: str, hint: str) -> Forbidden:
def _check_member_invite_quota(tenant_id: str) -> None:
"""Edition-aware member-count gate for invite.
Both branches self-disable on CE because ``FeatureService.get_features``
leaves ``billing.enabled`` and ``workspace_members.enabled`` False by
default; SaaS billing API and EE license activation are what flip them on.
Mirrors the two checks the console invite path performs (decorator at
``console/wraps.py:106`` for billing + inline at
``console/workspace/members.py:130`` for license).
"""
features = FeatureService.get_features(tenant_id)
if features.billing.enabled:
@ -121,9 +148,12 @@ def _check_member_invite_quota(tenant_id: str) -> None:
@openapi_ns.route("/workspaces")
class WorkspacesApi(Resource):
@openapi_ns.response(200, "Workspace list", openapi_ns.models[WorkspaceListResponse.__name__])
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def get(self, *, auth_data: AuthData):
rows = TenantService.get_workspaces_for_account(db.session, str(auth_data.account_id))
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
def get(self):
ctx = get_auth_ctx()
rows = TenantService.get_workspaces_for_account(db.session, str(ctx.account_id))
return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows))).model_dump(mode="json"), 200
@ -131,9 +161,12 @@ class WorkspacesApi(Resource):
@openapi_ns.route("/workspaces/<string:workspace_id>")
class WorkspaceByIdApi(Resource):
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def get(self, workspace_id: str, *, auth_data: AuthData):
row = TenantService.find_workspace_for_account(db.session, str(auth_data.account_id), workspace_id)
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
def get(self, workspace_id: str):
ctx = get_auth_ctx()
row = TenantService.find_workspace_for_account(db.session, str(ctx.account_id), workspace_id)
# 404 (not 403) on non-member so workspace IDs don't leak across tenants.
if row is None:
raise NotFound("workspace not found")
@ -152,17 +185,21 @@ class WorkspaceSwitchApi(Resource):
"""
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
@require_workspace_role()
def post(self, workspace_id: str, *, auth_data: AuthData):
account = _load_account(auth_data.account_id)
def post(self, workspace_id: str):
ctx = get_auth_ctx()
account = _load_account(ctx.account_id)
try:
TenantService.switch_tenant(account, workspace_id)
except AccountNotLinkTenantError:
# Membership existed at gate time but Tenant.status != NORMAL or
# the row was just removed — treat as not-found.
raise NotFound("workspace not found")
row = TenantService.find_workspace_for_account(db.session, str(auth_data.account_id), workspace_id)
row = TenantService.find_workspace_for_account(db.session, str(ctx.account_id), workspace_id)
if row is None:
raise NotFound("workspace not found")
tenant, membership = row
@ -179,15 +216,20 @@ class WorkspaceMembersApi(Resource):
@openapi_ns.doc(params=query_params_from_model(MemberListQuery))
@openapi_ns.response(200, "Member list", openapi_ns.models[MemberListResponse.__name__])
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
@require_workspace_role()
def get(self, workspace_id: str, *, auth_data: AuthData):
def get(self, workspace_id: str):
try:
query = MemberListQuery.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise BadRequest(str(exc))
tenant = _load_tenant(workspace_id)
# Members per workspace are bounded by SaaS plan caps (≤50) or EE
# license seats (low thousands worst-case), so we materialize and
# slice in-memory rather than push pagination into the service —
# matches how the rest of the service exposes member lists.
members = TenantService.get_tenant_members(tenant)
total = len(members)
start = (query.page - 1) * query.limit
@ -202,11 +244,13 @@ class WorkspaceMembersApi(Resource):
@openapi_ns.expect(openapi_ns.models[MemberInvitePayload.__name__])
@openapi_ns.response(201, "Member invited", openapi_ns.models[MemberInviteResponse.__name__])
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
def post(self, workspace_id: str, *, auth_data: AuthData):
def post(self, workspace_id: str):
payload = _validate_body(MemberInvitePayload)
inviter = _load_account(auth_data.account_id)
ctx = get_auth_ctx()
inviter = _load_account(ctx.account_id)
tenant = _load_tenant(workspace_id)
_check_member_invite_quota(str(tenant.id))
@ -253,10 +297,12 @@ class WorkspaceMemberApi(Resource):
"""
@openapi_ns.response(200, "Member removed", openapi_ns.models[MemberActionResponse.__name__])
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
def delete(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
operator = _load_account(auth_data.account_id)
def delete(self, workspace_id: str, member_id: str):
ctx = get_auth_ctx()
operator = _load_account(ctx.account_id)
tenant = _load_tenant(workspace_id)
member = AccountService.get_account_by_id(db.session, member_id)
if member is None:
@ -284,11 +330,13 @@ class WorkspaceMemberRoleApi(Resource):
@openapi_ns.expect(openapi_ns.models[MemberRoleUpdatePayload.__name__])
@openapi_ns.response(200, "Role updated", openapi_ns.models[MemberActionResponse.__name__])
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
def put(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
def put(self, workspace_id: str, member_id: str):
payload = _validate_body(MemberRoleUpdatePayload)
operator = _load_account(auth_data.account_id)
ctx = get_auth_ctx()
operator = _load_account(ctx.account_id)
tenant = _load_tenant(workspace_id)
member = AccountService.get_account_by_id(db.session, member_id)
if member is None:

View File

@ -27,7 +27,6 @@ from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
from core.workflow.node_factory import get_default_root_node_id
from core.workflow.nodes.agent_v2.session_cleanup_layer import build_workflow_agent_session_cleanup_layer
from core.workflow.system_variables import (
build_bootstrap_variables,
build_system_variables,
@ -240,7 +239,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
)
workflow_entry.graph_engine.layer(persistence_layer)
workflow_entry.graph_engine.layer(build_workflow_agent_session_cleanup_layer())
conversation_variable_layer = ConversationVariablePersistenceLayer(
ConversationVariableUpdater(session_factory.get_session_maker())
)

View File

@ -10,7 +10,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
from core.workflow.node_factory import get_default_root_node_id
from core.workflow.nodes.agent_v2.session_cleanup_layer import build_workflow_agent_session_cleanup_layer
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
from core.workflow.workflow_entry import WorkflowEntry
@ -167,7 +166,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
)
workflow_entry.graph_engine.layer(persistence_layer)
workflow_entry.graph_engine.layer(build_workflow_agent_session_cleanup_layer())
for layer in self._graph_engine_layers:
workflow_entry.graph_engine.layer(layer)

View File

@ -475,7 +475,6 @@ class DifyNodeFactory(NodeFactory):
from core.workflow.nodes.agent_v2.file_tenant_validator import UploadFileTenantValidator
from core.workflow.nodes.agent_v2.output_failure_orchestrator import OutputFailureOrchestrator
from core.workflow.nodes.agent_v2.output_type_checker import PerOutputTypeChecker
from core.workflow.nodes.agent_v2.session_store import WorkflowAgentRuntimeSessionStore
return {
"binding_resolver": WorkflowAgentBindingResolver(),
@ -495,7 +494,6 @@ class DifyNodeFactory(NodeFactory):
# outputs contain no file refs.
"type_checker": PerOutputTypeChecker(file_validator=UploadFileTenantValidator()),
"failure_orchestrator": OutputFailureOrchestrator(),
"session_store": WorkflowAgentRuntimeSessionStore(),
}
return {
"strategy_resolver": self._agent_strategy_resolver,

View File

@ -1,11 +1,8 @@
from __future__ import annotations
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from agenton.compositor import CompositorSessionSnapshot
from clients.agent_backend import (
AgentBackendError,
AgentBackendHTTPError,
@ -20,14 +17,11 @@ from clients.agent_backend import (
AgentBackendStreamInternalEvent,
AgentBackendTransportError,
AgentBackendValidationError,
CleanupLayerSpec,
extract_cleanup_layer_specs,
)
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.workflow.system_variables import SystemVariableKey, get_system_text
from graphon.entities.pause_reason import SchedulingPause
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from graphon.node_events import NodeEventBase, NodeRunResult, PauseRequestedEvent, StreamCompletedEvent
from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
from graphon.nodes.base.node import Node
from models.agent_config_entities import WorkflowNodeJobConfig
@ -46,14 +40,11 @@ from .runtime_request_builder import (
WorkflowAgentRuntimeRequestBuilder,
WorkflowAgentRuntimeRequestBuildError,
)
from .session_store import WorkflowAgentRuntimeSessionStore, WorkflowAgentSessionScope
if TYPE_CHECKING:
from graphon.entities import GraphInitParams
from graphon.runtime import GraphRuntimeState
logger = logging.getLogger(__name__)
# Stage 4 §5+§7: the terminal events that `_consume_event_stream` may return.
# Stream + started events are filtered out before we yield; transport errors
@ -83,7 +74,6 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
output_adapter: WorkflowAgentOutputAdapter,
type_checker: PerOutputTypeChecker,
failure_orchestrator: OutputFailureOrchestrator,
session_store: WorkflowAgentRuntimeSessionStore | None = None,
) -> None:
super().__init__(
node_id=node_id,
@ -98,7 +88,6 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
self._output_adapter = output_adapter
self._type_checker = type_checker
self._failure_orchestrator = failure_orchestrator
self._session_store = session_store
@classmethod
def version(cls) -> str:
@ -145,17 +134,6 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
"agent_config_snapshot_id": bundle.snapshot.id,
"binding_id": bundle.binding.id,
}
session_scope = WorkflowAgentSessionScope(
tenant_id=dify_ctx.tenant_id,
app_id=dify_ctx.app_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
node_id=self._node_id,
node_execution_id=self.id,
binding_id=bundle.binding.id,
agent_id=bundle.agent.id,
agent_config_snapshot_id=bundle.snapshot.id,
)
# Stage 4 §4.1 (D-3): use effective outputs so defaults flow through both
# the backend request and the post-run type check.
@ -169,9 +147,6 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
attempt = 0
while True:
try:
session_snapshot = None
if self._session_store is not None:
session_snapshot = self._session_store.load_active_snapshot(session_scope)
runtime_request = self._runtime_request_builder.build(
WorkflowAgentRuntimeBuildContext(
dify_context=dify_ctx,
@ -184,7 +159,6 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
agent=bundle.agent,
snapshot=bundle.snapshot,
attempt=attempt,
session_snapshot=session_snapshot,
)
)
except WorkflowAgentRuntimeRequestBuildError as error:
@ -247,35 +221,9 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
)
return
if isinstance(terminal_event, AgentBackendRunPausedInternalEvent):
self._save_session_snapshot(
session_scope=session_scope,
backend_run_id=terminal_event.run_id,
snapshot=terminal_event.session_snapshot,
composition_layer_specs=extract_cleanup_layer_specs(runtime_request.request.composition),
metadata=metadata,
)
yield PauseRequestedEvent(
reason=SchedulingPause(
message=terminal_event.message
or "Agent backend run requested workflow pause for external input."
)
)
return
# Non-success terminal (failed / cancelled) skips per-output
# post-processing — the backend itself already failed. We also retire
# the local ACTIVE session row so a workflow loop back into the same
# Agent node cannot resume from a stale snapshot. The failed agent
# backend layers (suspended per ``on_exit``) are left for agent
# backend's own GC; this row will no longer be picked up by the
# workflow-terminal cleanup layer.
# Non-success terminal (failed / cancelled / paused) skips per-output
# post-processing — the backend itself already failed.
if not isinstance(terminal_event, AgentBackendRunSucceededInternalEvent):
self._mark_session_cleaned_on_failure(
session_scope=session_scope,
backend_run_id=terminal_event.run_id,
metadata=metadata,
)
yield StreamCompletedEvent(
node_run_result=self._output_adapter.build_failure_result(
event=terminal_event,
@ -286,14 +234,6 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
)
return
self._save_session_snapshot(
session_scope=session_scope,
backend_run_id=terminal_event.run_id,
snapshot=terminal_event.session_snapshot,
composition_layer_specs=extract_cleanup_layer_specs(runtime_request.request.composition),
metadata=metadata,
)
# ──── Stage 4: per-output type check ────
type_check = self._type_checker.check(
declared_outputs=effective_outputs,
@ -444,75 +384,6 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
],
}
def _save_session_snapshot(
self,
*,
session_scope: WorkflowAgentSessionScope,
backend_run_id: str,
snapshot: CompositorSessionSnapshot | None,
composition_layer_specs: list[CleanupLayerSpec],
metadata: dict[str, Any],
) -> None:
if self._session_store is None:
return
try:
self._session_store.save_active_snapshot(
scope=session_scope,
backend_run_id=backend_run_id,
snapshot=snapshot,
composition_layer_specs=composition_layer_specs,
)
agent_backend = dict(metadata.get("agent_backend") or {})
agent_backend["session_snapshot_persisted"] = snapshot is not None
metadata["agent_backend"] = agent_backend
except Exception:
logger.warning(
"Failed to persist workflow Agent runtime session snapshot: "
"tenant_id=%s workflow_run_id=%s node_id=%s binding_id=%s agent_id=%s backend_run_id=%s",
session_scope.tenant_id,
session_scope.workflow_run_id,
session_scope.node_id,
session_scope.binding_id,
session_scope.agent_id,
backend_run_id,
exc_info=True,
)
agent_backend = dict(metadata.get("agent_backend") or {})
agent_backend["session_snapshot_persisted"] = False
agent_backend["session_snapshot_persist_error"] = "workflow_agent_runtime_session_store_error"
metadata["agent_backend"] = agent_backend
def _mark_session_cleaned_on_failure(
self,
*,
session_scope: WorkflowAgentSessionScope,
backend_run_id: str,
metadata: dict[str, Any],
) -> None:
if self._session_store is None:
return
try:
self._session_store.mark_cleaned(scope=session_scope, backend_run_id=backend_run_id)
agent_backend = dict(metadata.get("agent_backend") or {})
agent_backend["session_snapshot_cleaned_on_failure"] = True
metadata["agent_backend"] = agent_backend
except Exception:
logger.warning(
"Failed to mark workflow Agent runtime session cleaned on agent run failure: "
"tenant_id=%s workflow_run_id=%s node_id=%s binding_id=%s agent_id=%s backend_run_id=%s",
session_scope.tenant_id,
session_scope.workflow_run_id,
session_scope.node_id,
session_scope.binding_id,
session_scope.agent_id,
backend_run_id,
exc_info=True,
)
agent_backend = dict(metadata.get("agent_backend") or {})
agent_backend["session_snapshot_cleaned_on_failure"] = False
agent_backend["session_snapshot_cleanup_error"] = "workflow_agent_runtime_session_store_error"
metadata["agent_backend"] = agent_backend
@staticmethod
def _patch_event_with_defaults(
event: AgentBackendRunSucceededInternalEvent,

View File

@ -4,7 +4,6 @@ from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import Any, Literal, Protocol, cast
from agenton.compositor import CompositorSessionSnapshot
from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig
from dify_agent.protocol import CreateRunRequest
@ -29,7 +28,6 @@ from models.agent_config_entities import (
from models.agent_config_entities import (
effective_declared_outputs as _effective_declared_outputs,
)
from models.provider_ids import ModelProviderID
from .output_failure_orchestrator import retry_idempotency_key
from .plugin_tools_builder import WorkflowAgentPluginToolsBuilder, WorkflowAgentPluginToolsBuildError
@ -68,7 +66,6 @@ class WorkflowAgentRuntimeBuildContext:
# Stage 4 §7 / D-4: 0 for the first run, then incremented per retry. Drives the
# idempotency key so the backend treats each retry as a fresh request.
attempt: int = 0
session_snapshot: CompositorSessionSnapshot | None = None
@dataclass(frozen=True, slots=True)
@ -132,14 +129,11 @@ class WorkflowAgentRuntimeRequestBuilder:
request = self._request_builder.build_for_workflow_node(
AgentBackendWorkflowNodeRunInput(
model=AgentBackendModelConfig(
plugin_id=self._plugin_daemon_plugin_id(
plugin_id=agent_soul.model.plugin_id,
model_provider=agent_soul.model.model_provider,
),
model_provider=self._plugin_daemon_provider_name(agent_soul.model.model_provider),
plugin_id=agent_soul.model.plugin_id,
model_provider=agent_soul.model.model_provider,
model=agent_soul.model.model,
credentials=self._normalize_credentials(credentials),
model_settings=agent_soul.model.model_settings,
model_settings=cast(dict[str, Any], agent_soul.model.model_settings),
),
# The execution-context layer is now the only public protocol
# carrier for Dify tenant/user/run identifiers. ``user_id`` must
@ -164,7 +158,6 @@ class WorkflowAgentRuntimeRequestBuilder:
user_prompt=user_prompt,
output=self._build_output_config(node_job.declared_outputs),
tools=tools_layer,
session_snapshot=context.session_snapshot,
idempotency_key=self._idempotency_key(context),
metadata=metadata,
)
@ -184,20 +177,6 @@ class WorkflowAgentRuntimeRequestBuilder:
return "single_step"
return "workflow_run"
@staticmethod
def _plugin_daemon_plugin_id(*, plugin_id: str, model_provider: str) -> str:
"""Return the transport plugin id expected by plugin-daemon headers."""
if plugin_id.count("/") == 1:
return plugin_id
if plugin_id:
return ModelProviderID(plugin_id).plugin_id
return ModelProviderID(model_provider).plugin_id
@staticmethod
def _plugin_daemon_provider_name(model_provider: str) -> str:
"""Return the provider name expected by plugin-daemon dispatch payloads."""
return ModelProviderID(model_provider).provider_name
@staticmethod
def _idempotency_key(context: WorkflowAgentRuntimeBuildContext) -> str:
# Stage 4 §7 / D-4: retries get distinct keys (``...:retry-{attempt}``) so

View File

@ -1,247 +0,0 @@
from __future__ import annotations
import logging
from typing import override
from clients.agent_backend import AgentBackendError, AgentBackendRunClient, AgentBackendRunRequestBuilder
from clients.agent_backend.factory import create_agent_backend_run_client
from configs import dify_config
from core.workflow.system_variables import SystemVariableKey, get_system_text
from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import (
GraphEngineEvent,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunSucceededEvent,
)
from .session_store import StoredWorkflowAgentSession, WorkflowAgentRuntimeSessionStore
logger = logging.getLogger(__name__)
# Upper bound on how long a cleanup-only run is allowed to settle before the
# layer gives up and leaves the row ACTIVE so it can be retried later. Cleanup
# work is mostly local agent-backend bookkeeping (no LLM inference), so 30s is
# generous; a hung backend should never block workflow termination beyond this.
_CLEANUP_WAIT_TIMEOUT_SECONDS = 30.0
class WorkflowAgentSessionCleanupLayer(GraphEngineLayer):
"""Retires workflow Agent session snapshots when a workflow reaches a terminal state.
Implementation notes — there are two failure modes the cleanup path has to
avoid simultaneously:
1. The agenton compositor on the agent-backend side validates the cleanup
request's session snapshot against the replayed composition before
running any lifecycle hook. If the snapshot's layer names diverge from
the composition, the run fails asynchronously with ``run_failed`` — but
the initial ``POST /runs`` already returned 202, so the API side has no
visibility of the failure unless it waits for terminal status. The
``composition_layer_specs`` persistence in A.1A.4 plus the
``_filter_snapshot_to_specs`` shape in ``build_cleanup_request`` keeps
the two name lists in sync.
2. The current agent backend's ``runner.py::_run_agent`` always invokes
``run.get_layer("llm")`` and the structured-output / history validators
before exiting any slot — there is no ``purpose: "cleanup"`` branch
yet. A truly cleanup-only request (no LLM layer) therefore still
crashes inside the runner with ``Layer 'llm' is not defined in this
compositor run.``. Until the backend grows a cleanup-only purpose,
this layer **does not issue an HTTP cleanup run**: it simply retires
the local snapshot row so stale state cannot be re-resumed, and lets
the agent backend's own retention TTL release the suspended layers.
The HTTP-cleanup machinery (``build_cleanup_request`` + ``wait_run``) is
intentionally still wired into the request builder + integration tests so
that when the agent backend supports cleanup runs we can flip the switch
here with a one-line change (see ``_HTTP_CLEANUP_SUPPORTED``).
"""
# Flip to True once dify-agent's runner has a ``purpose=cleanup`` branch
# that skips the LLM/output/user-prompt invariants. Until then we only
# update the local row; the spec list is still persisted so the future
# HTTP cleanup path has everything it needs.
_HTTP_CLEANUP_SUPPORTED: bool = False
_TERMINAL_EVENTS = (
GraphRunSucceededEvent,
GraphRunPartialSucceededEvent,
GraphRunFailedEvent,
GraphRunAbortedEvent,
)
def __init__(
self,
*,
session_store: WorkflowAgentRuntimeSessionStore,
request_builder: AgentBackendRunRequestBuilder,
agent_backend_client: AgentBackendRunClient | None,
cleanup_wait_timeout_seconds: float = _CLEANUP_WAIT_TIMEOUT_SECONDS,
) -> None:
super().__init__()
self._session_store = session_store
self._request_builder = request_builder
self._agent_backend_client = agent_backend_client
self._cleanup_wait_timeout_seconds = cleanup_wait_timeout_seconds
@override
def on_graph_start(self) -> None:
return
@override
def on_event(self, event: GraphEngineEvent) -> None:
if not isinstance(event, self._TERMINAL_EVENTS):
return
workflow_run_id = get_system_text(
self.graph_runtime_state.variable_pool,
SystemVariableKey.WORKFLOW_EXECUTION_ID,
)
if not workflow_run_id:
logger.warning("Skipping workflow Agent session cleanup: workflow_run_id is missing.")
return
for stored_session in self._session_store.list_active_sessions(workflow_run_id=workflow_run_id):
self._cleanup_session(stored_session)
@override
def on_graph_end(self, error: Exception | None) -> None:
return
def _cleanup_session(self, stored_session: StoredWorkflowAgentSession) -> None:
scope = stored_session.scope
if not self._HTTP_CLEANUP_SUPPORTED:
# Agent backend has no cleanup-only run mode yet (see class
# docstring). Retire the local row so future re-entries do not
# resume from stale state, and let the backend's retention TTL
# release the suspended layers on its own schedule.
logger.info(
"Workflow Agent session retired locally; HTTP cleanup is disabled "
"until the agent backend supports a cleanup-only run mode. "
"workflow_run_id=%s node_id=%s binding_id=%s agent_id=%s previous_run_id=%s",
scope.workflow_run_id,
scope.node_id,
scope.binding_id,
scope.agent_id,
stored_session.backend_run_id,
)
self._session_store.mark_cleaned(scope=scope, backend_run_id=stored_session.backend_run_id)
return
if self._agent_backend_client is None:
# HTTP cleanup was enabled by the caller but no client was wired
# in (e.g. the API runs without AGENT_BACKEND_BASE_URL configured).
# Leave the row ACTIVE so an operator restart with proper config
# can drive the cleanup; do not silently retire it.
logger.warning(
"Skipping Agent backend cleanup: HTTP cleanup is enabled but no agent "
"backend client is wired in. workflow_run_id=%s node_id=%s agent_id=%s",
scope.workflow_run_id,
scope.node_id,
scope.agent_id,
)
return
if not stored_session.composition_layer_specs:
# Sessions persisted before A.1 landed do not carry the spec list,
# so we cannot replay a valid cleanup composition. Leave the row
# ACTIVE and warn so the absence shows up in observability rather
# than being silently swallowed by a doomed cleanup run.
logger.warning(
"Skipping Agent backend cleanup: no composition_layer_specs persisted. "
"workflow_run_id=%s node_id=%s agent_id=%s",
scope.workflow_run_id,
scope.node_id,
scope.agent_id,
)
return
request = self._request_builder.build_cleanup_request(
session_snapshot=stored_session.session_snapshot,
composition_layer_specs=stored_session.composition_layer_specs,
idempotency_key=f"{scope.workflow_run_id}:{scope.node_id}:{scope.binding_id}:agent-session-cleanup",
metadata={
"tenant_id": scope.tenant_id,
"app_id": scope.app_id,
"workflow_id": scope.workflow_id,
"workflow_run_id": scope.workflow_run_id,
"node_id": scope.node_id,
"node_execution_id": scope.node_execution_id,
"binding_id": scope.binding_id,
"agent_id": scope.agent_id,
"agent_config_snapshot_id": scope.agent_config_snapshot_id,
"previous_agent_backend_run_id": stored_session.backend_run_id,
},
)
try:
response = self._agent_backend_client.create_run(request)
except AgentBackendError:
logger.warning(
"Agent backend session cleanup request failed: workflow_run_id=%s node_id=%s agent_id=%s",
scope.workflow_run_id,
scope.node_id,
scope.agent_id,
exc_info=True,
)
return
try:
status_response = self._agent_backend_client.wait_run(
response.run_id, timeout_seconds=self._cleanup_wait_timeout_seconds
)
except AgentBackendError:
logger.warning(
"Agent backend session cleanup wait_run failed: "
"workflow_run_id=%s node_id=%s agent_id=%s cleanup_run_id=%s",
scope.workflow_run_id,
scope.node_id,
scope.agent_id,
response.run_id,
exc_info=True,
)
return
if status_response.status != "succeeded":
logger.warning(
"Agent backend session cleanup did not succeed: status=%s error=%s "
"workflow_run_id=%s node_id=%s agent_id=%s cleanup_run_id=%s",
status_response.status,
status_response.error,
scope.workflow_run_id,
scope.node_id,
scope.agent_id,
response.run_id,
)
return
self._session_store.mark_cleaned(scope=scope, backend_run_id=response.run_id)
def build_workflow_agent_session_cleanup_layer() -> WorkflowAgentSessionCleanupLayer:
"""Wire the cleanup layer with the standard production dependencies.
The agent backend client is constructed only when ``AGENT_BACKEND_BASE_URL``
is configured (or the deterministic fake is explicitly enabled). When
neither is set — for example unit tests that bring up the workflow runner
without an Agent node — we pass ``None`` so the layer stays harmless. With
``_HTTP_CLEANUP_SUPPORTED = False`` the local-retire branch never touches
the client anyway, but keeping it ``None`` avoids importing httpx and lets
test harnesses skip backend configuration.
"""
agent_backend_client: AgentBackendRunClient | None
if dify_config.AGENT_BACKEND_USE_FAKE or dify_config.AGENT_BACKEND_BASE_URL:
agent_backend_client = create_agent_backend_run_client(
base_url=dify_config.AGENT_BACKEND_BASE_URL,
use_fake=dify_config.AGENT_BACKEND_USE_FAKE,
fake_scenario=dify_config.AGENT_BACKEND_FAKE_SCENARIO,
)
else:
agent_backend_client = None
return WorkflowAgentSessionCleanupLayer(
session_store=WorkflowAgentRuntimeSessionStore(),
request_builder=AgentBackendRunRequestBuilder(),
agent_backend_client=agent_backend_client,
)

View File

@ -1,179 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass, field
from agenton.compositor import CompositorSessionSnapshot
from pydantic import TypeAdapter
from sqlalchemy import select
from clients.agent_backend.request_builder import CleanupLayerSpec
from core.db.session_factory import session_factory
from libs.datetime_utils import naive_utc_now
from models.agent import (
WorkflowAgentRuntimeSession,
WorkflowAgentRuntimeSessionStatus,
)
_SPECS_ADAPTER: TypeAdapter[list[CleanupLayerSpec]] = TypeAdapter(list[CleanupLayerSpec])
def _serialize_specs(specs: list[CleanupLayerSpec]) -> str:
return _SPECS_ADAPTER.dump_json(specs).decode()
def _deserialize_specs(value: str | None) -> list[CleanupLayerSpec]:
if not value:
return []
return _SPECS_ADAPTER.validate_json(value)
@dataclass(frozen=True, slots=True)
class WorkflowAgentSessionScope:
tenant_id: str
app_id: str
workflow_id: str
workflow_run_id: str | None
node_id: str
node_execution_id: str
binding_id: str
agent_id: str
agent_config_snapshot_id: str
@dataclass(frozen=True, slots=True)
class StoredWorkflowAgentSession:
scope: WorkflowAgentSessionScope
session_snapshot: CompositorSessionSnapshot
backend_run_id: str | None
composition_layer_specs: list[CleanupLayerSpec] = field(default_factory=list)
class WorkflowAgentRuntimeSessionStore:
"""Stores Agent backend session snapshots for workflow Agent node re-entry."""
def load_active_snapshot(self, scope: WorkflowAgentSessionScope) -> CompositorSessionSnapshot | None:
if scope.workflow_run_id is None:
return None
with session_factory.create_session() as session:
row = session.scalar(
select(WorkflowAgentRuntimeSession).where(
WorkflowAgentRuntimeSession.tenant_id == scope.tenant_id,
WorkflowAgentRuntimeSession.workflow_run_id == scope.workflow_run_id,
WorkflowAgentRuntimeSession.node_id == scope.node_id,
WorkflowAgentRuntimeSession.binding_id == scope.binding_id,
WorkflowAgentRuntimeSession.agent_id == scope.agent_id,
WorkflowAgentRuntimeSession.status == WorkflowAgentRuntimeSessionStatus.ACTIVE,
)
)
if row is None:
return None
return CompositorSessionSnapshot.model_validate_json(row.session_snapshot)
def list_active_sessions(self, *, workflow_run_id: str) -> list[StoredWorkflowAgentSession]:
with session_factory.create_session() as session:
rows = session.scalars(
select(WorkflowAgentRuntimeSession).where(
WorkflowAgentRuntimeSession.workflow_run_id == workflow_run_id,
WorkflowAgentRuntimeSession.status == WorkflowAgentRuntimeSessionStatus.ACTIVE,
)
).all()
return [
StoredWorkflowAgentSession(
scope=WorkflowAgentSessionScope(
tenant_id=row.tenant_id,
app_id=row.app_id,
workflow_id=row.workflow_id,
workflow_run_id=row.workflow_run_id,
node_id=row.node_id,
node_execution_id=row.node_execution_id or "",
binding_id=row.binding_id,
agent_id=row.agent_id,
agent_config_snapshot_id=row.agent_config_snapshot_id,
),
session_snapshot=CompositorSessionSnapshot.model_validate_json(row.session_snapshot),
backend_run_id=row.backend_run_id,
composition_layer_specs=_deserialize_specs(row.composition_layer_specs),
)
for row in rows
]
def save_active_snapshot(
self,
*,
scope: WorkflowAgentSessionScope,
backend_run_id: str,
snapshot: CompositorSessionSnapshot | None,
composition_layer_specs: list[CleanupLayerSpec],
) -> None:
if scope.workflow_run_id is None or snapshot is None:
return
snapshot_json = snapshot.model_dump_json()
specs_json = _serialize_specs(composition_layer_specs)
with session_factory.create_session() as session:
row = session.scalar(
select(WorkflowAgentRuntimeSession).where(
WorkflowAgentRuntimeSession.tenant_id == scope.tenant_id,
WorkflowAgentRuntimeSession.workflow_run_id == scope.workflow_run_id,
WorkflowAgentRuntimeSession.node_id == scope.node_id,
WorkflowAgentRuntimeSession.binding_id == scope.binding_id,
WorkflowAgentRuntimeSession.agent_id == scope.agent_id,
)
)
if row is None:
row = WorkflowAgentRuntimeSession(
tenant_id=scope.tenant_id,
app_id=scope.app_id,
workflow_id=scope.workflow_id,
workflow_run_id=scope.workflow_run_id,
node_id=scope.node_id,
node_execution_id=scope.node_execution_id,
binding_id=scope.binding_id,
agent_id=scope.agent_id,
agent_config_snapshot_id=scope.agent_config_snapshot_id,
backend_run_id=backend_run_id,
session_snapshot=snapshot_json,
composition_layer_specs=specs_json,
status=WorkflowAgentRuntimeSessionStatus.ACTIVE,
)
session.add(row)
else:
row.node_execution_id = scope.node_execution_id
row.agent_config_snapshot_id = scope.agent_config_snapshot_id
row.backend_run_id = backend_run_id
row.session_snapshot = snapshot_json
row.composition_layer_specs = specs_json
row.status = WorkflowAgentRuntimeSessionStatus.ACTIVE
row.cleaned_at = None
session.commit()
def mark_cleaned(self, *, scope: WorkflowAgentSessionScope, backend_run_id: str | None = None) -> None:
if scope.workflow_run_id is None:
return
with session_factory.create_session() as session:
row = session.scalar(
select(WorkflowAgentRuntimeSession).where(
WorkflowAgentRuntimeSession.tenant_id == scope.tenant_id,
WorkflowAgentRuntimeSession.workflow_run_id == scope.workflow_run_id,
WorkflowAgentRuntimeSession.node_id == scope.node_id,
WorkflowAgentRuntimeSession.binding_id == scope.binding_id,
WorkflowAgentRuntimeSession.agent_id == scope.agent_id,
WorkflowAgentRuntimeSession.status == WorkflowAgentRuntimeSessionStatus.ACTIVE,
)
)
if row is None:
return
if backend_run_id is not None:
row.backend_run_id = backend_run_id
row.status = WorkflowAgentRuntimeSessionStatus.CLEANED
row.cleaned_at = naive_utc_now()
session.commit()
__all__ = [
"StoredWorkflowAgentSession",
"WorkflowAgentRuntimeSessionStore",
"WorkflowAgentSessionScope",
]

View File

@ -1,6 +1,5 @@
import posixpath
from collections.abc import Generator
from typing import override
import oss2 as aliyun_s3
@ -30,11 +29,9 @@ class AliyunOssStorage(BaseStorage):
cloudbox_id=dify_config.ALIYUN_CLOUDBOX_ID,
)
@override
def save(self, filename, data):
self.client.put_object(self.__wrapper_folder_filename(filename), data)
@override
def load_once(self, filename: str) -> bytes:
obj = self.client.get_object(self.__wrapper_folder_filename(filename))
data = obj.read()
@ -42,21 +39,17 @@ class AliyunOssStorage(BaseStorage):
return b""
return data
@override
def load_stream(self, filename: str) -> Generator:
obj = self.client.get_object(self.__wrapper_folder_filename(filename))
while chunk := obj.read(4096):
yield chunk
@override
def download(self, filename: str, target_filepath):
self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath)
@override
def exists(self, filename: str):
return self.client.object_exists(self.__wrapper_folder_filename(filename))
@override
def delete(self, filename: str):
self.client.delete_object(self.__wrapper_folder_filename(filename))

View File

@ -1,6 +1,5 @@
import logging
from collections.abc import Generator
from typing import override
import boto3
from botocore.client import Config
@ -49,11 +48,9 @@ class AwsS3Storage(BaseStorage):
# other error, raise exception
raise
@override
def save(self, filename, data):
self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
@override
def load_once(self, filename: str) -> bytes:
try:
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
@ -64,7 +61,6 @@ class AwsS3Storage(BaseStorage):
raise
return data
@override
def load_stream(self, filename: str) -> Generator:
try:
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
@ -77,11 +73,9 @@ class AwsS3Storage(BaseStorage):
else:
raise
@override
def download(self, filename, target_filepath):
self.client.download_file(self.bucket_name, filename, target_filepath)
@override
def exists(self, filename):
try:
self.client.head_object(Bucket=self.bucket_name, Key=filename)
@ -89,6 +83,5 @@ class AwsS3Storage(BaseStorage):
except:
return False
@override
def delete(self, filename: str):
self.client.delete_object(Bucket=self.bucket_name, Key=filename)

View File

@ -1,6 +1,5 @@
from collections.abc import Generator
from datetime import timedelta
from typing import override
from azure.identity import ChainedTokenCredential, DefaultAzureCredential
from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas
@ -27,7 +26,6 @@ class AzureBlobStorage(BaseStorage):
else:
self.credential = None
@override
def save(self, filename, data):
if not self.bucket_name:
return
@ -36,7 +34,6 @@ class AzureBlobStorage(BaseStorage):
blob_container = client.get_container_client(container=self.bucket_name)
blob_container.upload_blob(filename, data)
@override
def load_once(self, filename: str) -> bytes:
if not self.bucket_name:
raise FileNotFoundError("Azure bucket name is not configured.")
@ -49,7 +46,6 @@ class AzureBlobStorage(BaseStorage):
raise TypeError(f"Expected bytes from blob.readall(), got {type(data).__name__}")
return data
@override
def load_stream(self, filename: str) -> Generator:
if not self.bucket_name:
raise FileNotFoundError("Azure bucket name is not configured.")
@ -59,7 +55,6 @@ class AzureBlobStorage(BaseStorage):
blob_data = blob.download_blob()
yield from blob_data.chunks()
@override
def download(self, filename, target_filepath):
if not self.bucket_name:
return
@ -71,7 +66,6 @@ class AzureBlobStorage(BaseStorage):
blob_data = blob.download_blob()
blob_data.readinto(my_blob)
@override
def exists(self, filename):
if not self.bucket_name:
return False
@ -81,7 +75,6 @@ class AzureBlobStorage(BaseStorage):
blob = client.get_blob_client(container=self.bucket_name, blob=filename)
return blob.exists()
@override
def delete(self, filename: str):
if not self.bucket_name:
return

View File

@ -1,7 +1,6 @@
import base64
import hashlib
from collections.abc import Generator
from typing import override
from baidubce.auth.bce_credentials import BceCredentials
from baidubce.bce_client_configuration import BceClientConfiguration
@ -27,7 +26,6 @@ class BaiduObsStorage(BaseStorage):
self.client = BosClient(config=client_config)
@override
def save(self, filename, data):
md5 = hashlib.md5()
md5.update(data)
@ -36,29 +34,24 @@ class BaiduObsStorage(BaseStorage):
bucket_name=self.bucket_name, key=filename, data=data, content_length=len(data), content_md5=content_md5
)
@override
def load_once(self, filename: str) -> bytes:
response = self.client.get_object(bucket_name=self.bucket_name, key=filename)
data: bytes = response.data.read()
return data
@override
def load_stream(self, filename: str) -> Generator:
response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data
while chunk := response.read(4096):
yield chunk
@override
def download(self, filename, target_filepath):
self.client.get_object_to_file(bucket_name=self.bucket_name, key=filename, file_name=target_filepath)
@override
def exists(self, filename):
res = self.client.get_object_meta_data(bucket_name=self.bucket_name, key=filename)
if res is None:
return False
return True
@override
def delete(self, filename: str):
self.client.delete_object(bucket_name=self.bucket_name, key=filename)

View File

@ -10,7 +10,7 @@ import tempfile
from collections.abc import Generator
from io import BytesIO
from pathlib import Path
from typing import Any, override
from typing import Any
import clickzetta
from pydantic import BaseModel, model_validator
@ -251,7 +251,6 @@ class ClickZettaVolumeStorage(BaseStorage):
# Don't raise exception, let the operation continue
# The table might exist but not be visible due to permissions
@override
def save(self, filename: str, data: bytes):
"""Save data to ClickZetta Volume.
@ -305,7 +304,6 @@ class ClickZettaVolumeStorage(BaseStorage):
# Clean up temporary file
Path(temp_file_path).unlink(missing_ok=True)
@override
def load_once(self, filename: str) -> bytes:
"""Load file content from ClickZetta Volume.
@ -366,7 +364,6 @@ class ClickZettaVolumeStorage(BaseStorage):
logger.debug("File %s loaded from ClickZetta Volume", filename)
return content
@override
def load_stream(self, filename: str) -> Generator:
"""Load file as stream from ClickZetta Volume.
@ -385,7 +382,6 @@ class ClickZettaVolumeStorage(BaseStorage):
logger.debug("File %s loaded as stream from ClickZetta Volume", filename)
@override
def download(self, filename: str, target_filepath: str):
"""Download file from ClickZetta Volume to local path.
@ -399,7 +395,6 @@ class ClickZettaVolumeStorage(BaseStorage):
logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath)
@override
def exists(self, filename: str) -> bool:
"""Check if file exists in ClickZetta Volume.
@ -441,7 +436,6 @@ class ClickZettaVolumeStorage(BaseStorage):
logger.warning("Error checking file existence for %s: %s", filename, e)
return False
@override
def delete(self, filename: str):
"""Delete file from ClickZetta Volume.
@ -478,7 +472,6 @@ class ClickZettaVolumeStorage(BaseStorage):
logger.debug("File %s deleted from ClickZetta Volume", filename)
@override
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
"""Scan files and directories in ClickZetta Volume.

View File

@ -1,7 +1,7 @@
import base64
import io
from collections.abc import Generator
from typing import Any, override
from typing import Any
from google.cloud import storage as google_cloud_storage # type: ignore
from pydantic import TypeAdapter
@ -29,14 +29,12 @@ class GoogleCloudStorage(BaseStorage):
else:
self.client = google_cloud_storage.Client()
@override
def save(self, filename, data):
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.blob(filename)
with io.BytesIO(data) as stream:
blob.upload_from_file(stream)
@override
def load_once(self, filename: str) -> bytes:
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename)
@ -45,7 +43,6 @@ class GoogleCloudStorage(BaseStorage):
data: bytes = blob.download_as_bytes()
return data
@override
def load_stream(self, filename: str) -> Generator:
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename)
@ -55,7 +52,6 @@ class GoogleCloudStorage(BaseStorage):
while chunk := blob_stream.read(4096):
yield chunk
@override
def download(self, filename, target_filepath):
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename)
@ -63,13 +59,11 @@ class GoogleCloudStorage(BaseStorage):
raise FileNotFoundError("File not found")
blob.download_to_filename(target_filepath)
@override
def exists(self, filename):
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.blob(filename)
return blob.exists()
@override
def delete(self, filename: str):
bucket = self.client.get_bucket(self.bucket_name)
bucket.delete_blob(filename)

View File

@ -1,5 +1,4 @@
from collections.abc import Generator
from typing import override
from obs import ObsClient
@ -21,33 +20,27 @@ class HuaweiObsStorage(BaseStorage):
path_style=dify_config.HUAWEI_OBS_PATH_STYLE,
)
@override
def save(self, filename, data):
self.client.putObject(bucketName=self.bucket_name, objectKey=filename, content=data)
@override
def load_once(self, filename: str) -> bytes:
data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read()
return data
@override
def load_stream(self, filename: str) -> Generator:
response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response
while chunk := response.read(4096):
yield chunk
@override
def download(self, filename, target_filepath):
self.client.getObject(bucketName=self.bucket_name, objectKey=filename, downloadPath=target_filepath)
@override
def exists(self, filename):
res = self._get_meta(filename)
if res is None:
return False
return True
@override
def delete(self, filename: str):
self.client.deleteObject(bucketName=self.bucket_name, objectKey=filename)

View File

@ -2,7 +2,7 @@ import logging
import os
from collections.abc import Generator
from pathlib import Path
from typing import Any, override
from typing import Any
import opendal
from dotenv import dotenv_values
@ -41,12 +41,10 @@ class OpenDALStorage(BaseStorage):
logger.debug("opendal operator created with scheme %s", scheme)
logger.debug("added retry layer to opendal operator")
@override
def save(self, filename: str, data: bytes):
self.op.write(path=filename, bs=data)
logger.debug("file %s saved", filename)
@override
def load_once(self, filename: str) -> bytes:
if not self.exists(filename):
raise FileNotFoundError("File not found")
@ -55,7 +53,6 @@ class OpenDALStorage(BaseStorage):
logger.debug("file %s loaded", filename)
return content
@override
def load_stream(self, filename: str) -> Generator:
if not self.exists(filename):
raise FileNotFoundError("File not found")
@ -70,7 +67,6 @@ class OpenDALStorage(BaseStorage):
yield chunk
logger.debug("file %s loaded as stream", filename)
@override
def download(self, filename: str, target_filepath: str):
if not self.exists(filename):
raise FileNotFoundError("File not found")
@ -78,11 +74,9 @@ class OpenDALStorage(BaseStorage):
Path(target_filepath).write_bytes(self.op.read(path=filename))
logger.debug("file %s downloaded to %s", filename, target_filepath)
@override
def exists(self, filename: str) -> bool:
return self.op.exists(path=filename)
@override
def delete(self, filename: str):
if self.exists(filename):
self.op.delete(path=filename)
@ -90,7 +84,6 @@ class OpenDALStorage(BaseStorage):
return
logger.debug("file %s not found, skip delete", filename)
@override
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
if not self.exists(path):
raise FileNotFoundError("Path not found")

View File

@ -1,5 +1,4 @@
from collections.abc import Generator
from typing import override
import boto3
from botocore.exceptions import ClientError
@ -23,11 +22,9 @@ class OracleOCIStorage(BaseStorage):
region_name=dify_config.OCI_REGION,
)
@override
def save(self, filename, data):
self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
@override
def load_once(self, filename: str) -> bytes:
try:
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
@ -38,7 +35,6 @@ class OracleOCIStorage(BaseStorage):
raise
return data
@override
def load_stream(self, filename: str) -> Generator:
try:
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
@ -49,11 +45,9 @@ class OracleOCIStorage(BaseStorage):
else:
raise
@override
def download(self, filename, target_filepath):
self.client.download_file(self.bucket_name, filename, target_filepath)
@override
def exists(self, filename):
try:
self.client.head_object(Bucket=self.bucket_name, Key=filename)
@ -61,6 +55,5 @@ class OracleOCIStorage(BaseStorage):
except:
return False
@override
def delete(self, filename: str):
self.client.delete_object(Bucket=self.bucket_name, Key=filename)

View File

@ -1,7 +1,6 @@
import io
from collections.abc import Generator
from pathlib import Path
from typing import override
from supabase import Client
@ -29,35 +28,29 @@ class SupabaseStorage(BaseStorage):
if not self.bucket_exists():
self.client.storage.create_bucket(id=id, name=bucket_name)
@override
def save(self, filename, data):
self.client.storage.from_(self.bucket_name).upload(filename, data)
@override
def load_once(self, filename: str) -> bytes:
content: bytes = self.client.storage.from_(self.bucket_name).download(filename)
return content
@override
def load_stream(self, filename: str) -> Generator:
result = self.client.storage.from_(self.bucket_name).download(filename)
byte_stream = io.BytesIO(result)
while chunk := byte_stream.read(4096): # Read in chunks of 4KB
yield chunk
@override
def download(self, filename, target_filepath):
result = self.client.storage.from_(self.bucket_name).download(filename)
Path(target_filepath).write_bytes(result)
@override
def exists(self, filename):
result = self.client.storage.from_(self.bucket_name).list(path=filename)
if len(result) > 0:
return True
return False
@override
def delete(self, filename: str):
self.client.storage.from_(self.bucket_name).remove([filename])

View File

@ -1,5 +1,4 @@
from collections.abc import Generator
from typing import override
from qcloud_cos import CosConfig, CosS3Client
@ -30,29 +29,23 @@ class TencentCosStorage(BaseStorage):
)
self.client = CosS3Client(config)
@override
def save(self, filename, data):
self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename)
@override
def load_once(self, filename: str) -> bytes:
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read()
return data
@override
def load_stream(self, filename: str) -> Generator:
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
yield from response["Body"].get_stream(chunk_size=4096)
@override
def download(self, filename, target_filepath):
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
response["Body"].get_stream_to_file(target_filepath)
@override
def exists(self, filename):
return self.client.object_exists(Bucket=self.bucket_name, Key=filename)
@override
def delete(self, filename: str):
self.client.delete_object(Bucket=self.bucket_name, Key=filename)

View File

@ -1,5 +1,4 @@
from collections.abc import Generator
from typing import override
import tos
@ -28,13 +27,11 @@ class VolcengineTosStorage(BaseStorage):
region=dify_config.VOLCENGINE_TOS_REGION,
)
@override
def save(self, filename, data):
if not self.bucket_name:
raise ValueError("VOLCENGINE_TOS_BUCKET_NAME is not set")
self.client.put_object(bucket=self.bucket_name, key=filename, content=data)
@override
def load_once(self, filename: str) -> bytes:
if not self.bucket_name:
raise FileNotFoundError("VOLCENGINE_TOS_BUCKET_NAME is not set")
@ -43,7 +40,6 @@ class VolcengineTosStorage(BaseStorage):
raise TypeError(f"Expected bytes, got {type(data).__name__}")
return data
@override
def load_stream(self, filename: str) -> Generator:
if not self.bucket_name:
raise FileNotFoundError("VOLCENGINE_TOS_BUCKET_NAME is not set")
@ -51,13 +47,11 @@ class VolcengineTosStorage(BaseStorage):
while chunk := response.read(4096):
yield chunk
@override
def download(self, filename, target_filepath):
if not self.bucket_name:
raise ValueError("VOLCENGINE_TOS_BUCKET_NAME is not set")
self.client.get_object_to_file(bucket=self.bucket_name, key=filename, file_path=target_filepath)
@override
def exists(self, filename):
if not self.bucket_name:
return False
@ -66,7 +60,6 @@ class VolcengineTosStorage(BaseStorage):
return False
return True
@override
def delete(self, filename: str):
if not self.bucket_name:
return

View File

@ -43,11 +43,6 @@ class SubjectType(StrEnum):
EXTERNAL_SSO = "external_sso"
class TokenType(StrEnum):
OAUTH_ACCOUNT = "oauth_account"
OAUTH_EXTERNAL_SSO = "oauth_external_sso"
class Scope(StrEnum):
"""Catalog of bearer scopes recognised by the openapi surface.
@ -60,8 +55,6 @@ class Scope(StrEnum):
APPS_READ = "apps:read"
APPS_READ_PERMITTED_EXTERNAL = "apps:read:permitted-external"
APPS_RUN = "apps:run"
WORKSPACE_READ = "workspace:read"
WORKSPACE_WRITE = "workspace:write"
class Accepts(StrEnum):
@ -84,7 +77,7 @@ _SUBJECT_TO_ACCEPT: dict[SubjectType, Accepts] = {
class AuthContext:
"""Per-request identity published via :data:`_auth_ctx_var`
(see :func:`set_auth_ctx` / :func:`get_auth_ctx`). ``scopes`` /
``subject_type`` / ``token_type`` come from the TokenKind, not the DB —
``subject_type`` / ``source`` come from the TokenKind, not the DB —
corrupt rows can't elevate scope.
`verified_tenants` is a snapshot of the Layer-0 verdict cache at
@ -99,7 +92,7 @@ class AuthContext:
client_id: str | None
scopes: frozenset[Scope]
token_id: uuid.UUID
token_type: TokenType
source: str
expires_at: datetime | None
token_hash: str
verified_tenants: dict[str, bool] = field(default_factory=dict)
@ -187,7 +180,7 @@ class TokenKind:
prefix: str
subject_type: SubjectType
scopes: frozenset[Scope]
token_type: TokenType
source: str
resolver: Resolver
def matches(self, token: str) -> bool:
@ -298,7 +291,7 @@ class BearerAuthenticator:
client_id=row.client_id,
scopes=kind.scopes,
token_id=row.token_id,
token_type=kind.token_type,
source=kind.source,
expires_at=row.expires_at,
token_hash=token_hash,
verified_tenants=dict(row.verified_tenants),
@ -490,7 +483,7 @@ def check_workspace_membership(
account_id: uuid.UUID | str,
tenant_id: str,
token_hash: str,
membership_cache: dict[str, bool],
cached_verdicts: dict[str, bool],
) -> None:
"""Layer-0 enforcement core. Raises `Forbidden` on deny, returns on allow.
@ -499,7 +492,7 @@ def check_workspace_membership(
short-circuiting on EE / SSO subjects before invoking — this function
runs the membership + active-status checks unconditionally.
"""
cached = membership_cache.get(tenant_id)
cached = cached_verdicts.get(tenant_id)
if cached is True:
return
if cached is False:
@ -537,7 +530,7 @@ def require_workspace_member(ctx: AuthContext, tenant_id: str) -> None:
account_id=ctx.account_id,
tenant_id=tenant_id,
token_hash=ctx.token_hash,
membership_cache=ctx.verified_tenants,
cached_verdicts=ctx.verified_tenants,
)
@ -671,14 +664,14 @@ def build_registry(session_factory, redis_client) -> TokenKindRegistry:
prefix=account.prefix,
subject_type=account.subject_type,
scopes=account.scopes,
token_type=TokenType.OAUTH_ACCOUNT,
source="oauth_account",
resolver=oauth.for_account(),
),
TokenKind(
prefix=external.prefix,
subject_type=external.subject_type,
scopes=external.scopes,
token_type=TokenType.OAUTH_EXTERNAL_SSO,
source="oauth_external_sso",
resolver=oauth.for_external_sso(),
),
]

View File

@ -1,90 +0,0 @@
"""add workflow agent runtime sessions
Revision ID: 7885bd53f9a9
Revises: d4a5e1f3c9b7
Create Date: 2026-05-27 09:53:54.711805
"""
import sqlalchemy as sa
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = "7885bd53f9a9"
down_revision = "d4a5e1f3c9b7"
branch_labels = None
depends_on = None
def _is_pg() -> bool:
return op.get_bind().dialect.name == "postgresql"
def _uuid_column(name: str, *, nullable: bool = False, primary_key: bool = False) -> sa.Column:
"""Match the ``uuidv7()`` default that other tables on Postgres rely on,
while staying portable on MySQL where the ORM supplies the id."""
kwargs: dict[str, object] = {"nullable": nullable, "primary_key": primary_key}
if primary_key and _is_pg():
kwargs["server_default"] = sa.text("uuidv7()")
return sa.Column(name, models.types.StringUUID(), **kwargs)
def upgrade() -> None:
op.create_table(
"workflow_agent_runtime_sessions",
_uuid_column("id", primary_key=True),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("app_id", models.types.StringUUID(), nullable=False),
sa.Column("workflow_id", models.types.StringUUID(), nullable=False),
sa.Column("workflow_run_id", models.types.StringUUID(), nullable=False),
sa.Column("node_id", sa.String(length=255), nullable=False),
sa.Column("node_execution_id", sa.String(length=255), nullable=True),
sa.Column("binding_id", models.types.StringUUID(), nullable=False),
sa.Column("agent_id", models.types.StringUUID(), nullable=False),
sa.Column("agent_config_snapshot_id", models.types.StringUUID(), nullable=False),
sa.Column("backend_run_id", sa.String(length=255), nullable=True),
sa.Column("session_snapshot", models.types.LongText(), nullable=False),
# MySQL rejects ``server_default`` on TEXT/BLOB columns. The JSON
# payload is always populated at the ORM layer via
# ``WorkflowAgentRuntimeSessionStore.save_active_snapshot`` so the
# missing DB-level default cannot leave new rows uninitialized.
sa.Column("composition_layer_specs", models.types.LongText(), nullable=False),
sa.Column(
"status",
sa.String(length=32),
server_default=sa.text("'active'"),
nullable=False,
),
sa.Column("cleaned_at", sa.DateTime(), nullable=True),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("workflow_agent_runtime_session_pkey")),
sa.UniqueConstraint(
"tenant_id",
"workflow_run_id",
"node_id",
"binding_id",
"agent_id",
name=op.f("workflow_agent_runtime_session_scope_unique"),
),
)
with op.batch_alter_table("workflow_agent_runtime_sessions", schema=None) as batch_op:
batch_op.create_index(
"workflow_agent_runtime_session_lookup_idx",
["tenant_id", "workflow_run_id", "node_id", "status"],
unique=False,
)
batch_op.create_index(
"workflow_agent_runtime_session_backend_run_idx",
["backend_run_id"],
unique=False,
)
def downgrade() -> None:
with op.batch_alter_table("workflow_agent_runtime_sessions", schema=None) as batch_op:
batch_op.drop_index("workflow_agent_runtime_session_backend_run_idx")
batch_op.drop_index("workflow_agent_runtime_session_lookup_idx")
op.drop_table("workflow_agent_runtime_sessions")

View File

@ -20,8 +20,6 @@ from .agent import (
AgentStatus,
WorkflowAgentBindingType,
WorkflowAgentNodeBinding,
WorkflowAgentRuntimeSession,
WorkflowAgentRuntimeSessionStatus,
)
from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from .comment import (
@ -237,8 +235,6 @@ __all__ = [
"Workflow",
"WorkflowAgentBindingType",
"WorkflowAgentNodeBinding",
"WorkflowAgentRuntimeSession",
"WorkflowAgentRuntimeSessionStatus",
"WorkflowAppLog",
"WorkflowAppLogCreatedFrom",
"WorkflowArchiveLog",

View File

@ -92,15 +92,6 @@ class WorkflowAgentBindingType(StrEnum):
INLINE_AGENT = "inline_agent"
class WorkflowAgentRuntimeSessionStatus(StrEnum):
"""Lifecycle state of an Agent backend session snapshot owned by a workflow run."""
# Snapshot can be reused by a later Agent run in the same workflow run.
ACTIVE = "active"
# Snapshot has been retired and must not be submitted to Agent backend again.
CLEANED = "cleaned"
class Agent(DefaultFieldsMixin, Base):
"""Workspace-scoped Agent identity used by Agent Roster and workflow-only agents."""
@ -282,56 +273,3 @@ class WorkflowAgentNodeBinding(DefaultFieldsMixin, Base):
if isinstance(self.node_job_config, str):
return json.loads(self.node_job_config)
return dict(self.node_job_config)
class WorkflowAgentRuntimeSession(DefaultFieldsMixin, Base):
"""Persisted Agent backend session snapshot for one workflow Agent node execution scope.
The snapshot is runtime state returned by Agent backend. It is intentionally
separate from Agent Soul snapshots and workflow node-job config.
"""
__tablename__ = "workflow_agent_runtime_sessions"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="workflow_agent_runtime_session_pkey"),
UniqueConstraint(
"tenant_id",
"workflow_run_id",
"node_id",
"binding_id",
"agent_id",
name="workflow_agent_runtime_session_scope_unique",
),
Index(
"workflow_agent_runtime_session_lookup_idx",
"tenant_id",
"workflow_run_id",
"node_id",
"status",
),
Index("workflow_agent_runtime_session_backend_run_idx", "backend_run_id"),
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
node_id: Mapped[str] = mapped_column(String(255), nullable=False)
node_execution_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
binding_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
agent_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
agent_config_snapshot_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
backend_run_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
session_snapshot: Mapped[str] = mapped_column(LongText, nullable=False)
# JSON-encoded list of ``WorkflowAgentSessionLayerSpec`` ({name, type, deps,
# config}). Drives Agent backend cleanup-only runs: the agenton compositor
# rejects a session snapshot whose layer names do not match the cleanup
# composition, so we must replay the same layer graph (minus credential-
# bearing plugin layers) when issuing the cleanup request.
composition_layer_specs: Mapped[str] = mapped_column(LongText, nullable=False, server_default="[]")
status: Mapped[WorkflowAgentRuntimeSessionStatus] = mapped_column(
EnumText(WorkflowAgentRuntimeSessionStatus, length=32),
nullable=False,
default=WorkflowAgentRuntimeSessionStatus.ACTIVE,
)
cleaned_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)

View File

@ -59,7 +59,7 @@ members = ["providers/vdb/*", "providers/trace/*"]
exclude = ["providers/vdb/__pycache__", "providers/trace/__pycache__"]
[tool.uv.sources]
dify-agent = { path = "../dify-agent", editable = true }
dify-agent = { path = "../dify-agent" }
dify-vdb-alibabacloud-mysql = { workspace = true }
dify-vdb-analyticdb = { workspace = true }
dify-vdb-baidu = { workspace = true }

View File

@ -1,298 +0,0 @@
from __future__ import annotations
from unittest.mock import patch
from uuid import uuid4
import pytest
from werkzeug.exceptions import HTTPException
import services
from controllers.console.auth.error import MemberNotInTenantError
from controllers.console.workspace import members as members_module
from controllers.console.workspace.members import MemberCancelInviteApi, MemberUpdateRoleApi, OwnerTransfer
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class WorkspaceMembersIntegrationFactory:
@staticmethod
def create_tenant(db_session_with_containers) -> Tenant:
tenant = Tenant(name=f"Tenant {uuid4()}", plan="basic", status=TenantStatus.NORMAL)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
return tenant
@staticmethod
def create_account(
db_session_with_containers,
*,
email_prefix: str,
tenant: Tenant | None = None,
role: TenantAccountRole = TenantAccountRole.NORMAL,
current: bool = False,
) -> Account:
account = Account(
name=f"Account {uuid4()}",
email=f"{email_prefix}-{uuid4()}@example.com",
password="hashed-password",
password_salt="salt",
interface_language="en-US",
timezone="UTC",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
if tenant is not None:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=role,
current=current,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
account.current_tenant = tenant
return account
@staticmethod
def create_owner_workspace(db_session_with_containers) -> tuple[Tenant, Account]:
tenant = WorkspaceMembersIntegrationFactory.create_tenant(db_session_with_containers)
owner = WorkspaceMembersIntegrationFactory.create_account(
db_session_with_containers,
email_prefix="owner",
tenant=tenant,
role=TenantAccountRole.OWNER,
current=True,
)
return tenant, owner
@staticmethod
def create_owner_transfer_token(account: Account) -> str:
_, token = members_module.AccountService.generate_owner_transfer_token(
account.email,
account=account,
code="123456",
additional_data={},
)
return token
@staticmethod
def get_join(db_session_with_containers, *, tenant: Tenant, account: Account) -> TenantAccountJoin:
tenant_id = tenant.id
account_id = account.id
db_session_with_containers.expire_all()
join = (
db_session_with_containers.query(TenantAccountJoin)
.filter_by(tenant_id=tenant_id, account_id=account_id)
.one()
)
return join
class TestMemberCancelInviteApiWithContainers:
def test_cancel_success(self, flask_app_with_containers, db_session_with_containers):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
factory = WorkspaceMembersIntegrationFactory
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
member = factory.create_account(db_session_with_containers, email_prefix="member")
with (
flask_app_with_containers.test_request_context("/"),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
patch.object(members_module.TenantService, "remove_member_from_tenant") as mock_remove_member,
):
result, status = method(api, member.id)
assert status == 200
assert result["result"] == "success"
mock_remove_member.assert_called_once()
called_tenant, called_member, called_current_user = mock_remove_member.call_args.args
assert called_tenant.id == tenant.id
assert called_member.id == member.id
assert called_current_user.id == current_user.id
def test_cancel_not_found(self, flask_app_with_containers, db_session_with_containers):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
factory = WorkspaceMembersIntegrationFactory
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
with (
flask_app_with_containers.test_request_context("/"),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
):
with pytest.raises(HTTPException):
method(api, str(uuid4()))
def test_cancel_cannot_operate_self(self, flask_app_with_containers, db_session_with_containers):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
factory = WorkspaceMembersIntegrationFactory
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
member = factory.create_account(db_session_with_containers, email_prefix="member")
with (
flask_app_with_containers.test_request_context("/"),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
patch.object(
members_module.TenantService,
"remove_member_from_tenant",
side_effect=services.errors.account.CannotOperateSelfError("x"),
),
):
result, status = method(api, member.id)
assert status == 400
assert result["code"] == "cannot-operate-self"
def test_cancel_no_permission(self, flask_app_with_containers, db_session_with_containers):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
factory = WorkspaceMembersIntegrationFactory
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
member = factory.create_account(db_session_with_containers, email_prefix="member")
with (
flask_app_with_containers.test_request_context("/"),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
patch.object(
members_module.TenantService,
"remove_member_from_tenant",
side_effect=services.errors.account.NoPermissionError("x"),
),
):
result, status = method(api, member.id)
assert status == 403
assert result["code"] == "forbidden"
def test_cancel_member_not_in_tenant(self, flask_app_with_containers, db_session_with_containers):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
factory = WorkspaceMembersIntegrationFactory
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
member = factory.create_account(db_session_with_containers, email_prefix="member")
with (
flask_app_with_containers.test_request_context("/"),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
patch.object(
members_module.TenantService,
"remove_member_from_tenant",
side_effect=services.errors.account.MemberNotInTenantError(),
),
):
result, status = method(api, member.id)
assert status == 404
assert result["code"] == "member-not-found"
class TestMemberUpdateRoleApiWithContainers:
def test_update_success(self, flask_app_with_containers, db_session_with_containers):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
factory = WorkspaceMembersIntegrationFactory
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
member = factory.create_account(
db_session_with_containers,
email_prefix="member",
tenant=tenant,
role=TenantAccountRole.EDITOR,
)
with (
flask_app_with_containers.test_request_context("/", json={"role": "normal"}),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
):
result = method(api, member.id)
if isinstance(result, tuple):
result = result[0]
assert result["result"] == "success"
assert (
factory.get_join(db_session_with_containers, tenant=tenant, account=member).role == TenantAccountRole.NORMAL
)
def test_update_member_not_found(self, flask_app_with_containers, db_session_with_containers):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
factory = WorkspaceMembersIntegrationFactory
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
with (
flask_app_with_containers.test_request_context("/", json={"role": "normal"}),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
):
with pytest.raises(HTTPException):
method(api, str(uuid4()))
class TestOwnerTransferApiWithContainers:
def test_member_not_in_tenant(self, flask_app_with_containers, db_session_with_containers):
api = OwnerTransfer()
method = unwrap(api.post)
factory = WorkspaceMembersIntegrationFactory
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
member = factory.create_account(db_session_with_containers, email_prefix="member")
token = factory.create_owner_transfer_token(current_user)
with (
flask_app_with_containers.test_request_context("/", json={"token": token}),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
):
with pytest.raises(MemberNotInTenantError):
method(api, member.id)
def test_member_not_found(self, flask_app_with_containers, db_session_with_containers):
api = OwnerTransfer()
method = unwrap(api.post)
factory = WorkspaceMembersIntegrationFactory
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
token = factory.create_owner_transfer_token(current_user)
with (
flask_app_with_containers.test_request_context("/", json={"token": token}),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
):
with pytest.raises(HTTPException):
method(api, str(uuid4()))
def test_transfer_success(self, flask_app_with_containers, db_session_with_containers):
api = OwnerTransfer()
method = unwrap(api.post)
factory = WorkspaceMembersIntegrationFactory
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
member = factory.create_account(
db_session_with_containers,
email_prefix="member",
tenant=tenant,
role=TenantAccountRole.NORMAL,
)
token = factory.create_owner_transfer_token(current_user)
with (
flask_app_with_containers.test_request_context("/", json={"token": token}),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
patch.object(members_module.AccountService, "send_new_owner_transfer_notify_email") as mock_new_owner_email,
patch.object(members_module.AccountService, "send_old_owner_transfer_notify_email") as mock_old_owner_email,
):
result = method(api, member.id)
assert result["result"] == "success"
assert (
factory.get_join(db_session_with_containers, tenant=tenant, account=member).role == TenantAccountRole.OWNER
)
assert (
factory.get_join(db_session_with_containers, tenant=tenant, account=current_user).role
== TenantAccountRole.ADMIN
)
mock_new_owner_email.assert_called_once()
mock_old_owner_email.assert_called_once()

View File

@ -1,84 +0,0 @@
"""
Integration tests for delete_account_task.
These tests keep billing and email dispatch mocked, but exercise the account
lookup through the real Testcontainers PostgreSQL session factory instead of a
patched session_factory mock.
"""
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from models.account import Account
from tasks.delete_account_task import delete_account_task
def _create_account(db_session: Session, *, email: str = "user@example.com") -> Account:
account = Account(
name=f"account-{uuid4()}",
email=email,
)
db_session.add(account)
db_session.commit()
return account
@pytest.fixture
def mock_external_dependencies(mocker):
billing_service = mocker.patch("tasks.delete_account_task.BillingService")
mail_task = mocker.patch("tasks.delete_account_task.send_deletion_success_task")
return billing_service, mail_task
def test_billing_enabled_account_exists_calls_billing_and_sends_email(
db_session_with_containers: Session, mock_external_dependencies, mocker
) -> None:
billing_service, mail_task = mock_external_dependencies
account = _create_account(db_session_with_containers, email="a@b.com")
mocker.patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True)
delete_account_task(account.id)
billing_service.delete_account.assert_called_once_with(account.id)
mail_task.delay.assert_called_once_with(account.email)
def test_billing_disabled_account_exists_sends_email_only(
db_session_with_containers: Session, mock_external_dependencies, mocker
) -> None:
billing_service, mail_task = mock_external_dependencies
account = _create_account(db_session_with_containers, email="x@y.com")
mocker.patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", False)
delete_account_task(account.id)
billing_service.delete_account.assert_not_called()
mail_task.delay.assert_called_once_with(account.email)
def test_billing_enabled_account_not_found_calls_billing_no_email(mock_external_dependencies, mocker, caplog) -> None:
billing_service, mail_task = mock_external_dependencies
account_id = str(uuid4())
mocker.patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True)
delete_account_task(account_id)
billing_service.delete_account.assert_called_once_with(account_id)
mail_task.delay.assert_not_called()
assert any("not found" in record.getMessage().lower() for record in caplog.records)
def test_billing_delete_raises_propagates_and_no_email(
db_session_with_containers: Session, mock_external_dependencies, mocker
) -> None:
billing_service, mail_task = mock_external_dependencies
account = _create_account(db_session_with_containers, email="err@example.com")
billing_service.delete_account.side_effect = RuntimeError("billing down")
mocker.patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True)
with pytest.raises(RuntimeError, match="billing down"):
delete_account_task(account.id)
mail_task.delay.assert_not_called()

View File

@ -1,134 +0,0 @@
"""Integration test for the cleanup request against the real agenton compositor.
The bug fixed by A+D was invisible to unit tests that use ``FakeAgentBackendRunClient``
because the fake client never runs agenton's ``_validate_session_snapshot``. This
test plugs a cleanup request through the real ``Compositor`` (with the same
providers the agent backend wires in production) so that the snapshot-vs-
composition name-order check would fail loudly if the cleanup builder ever
regressed back to the empty-composition shape.
"""
from __future__ import annotations
from typing import cast
import pytest
from agenton.compositor import Compositor, CompositorSessionSnapshot, LayerProvider
from agenton.compositor.schemas import LayerSessionSnapshot
from agenton.layers.base import LifecycleState
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID
from agenton_collections.layers.plain.basic import PromptLayer
from agenton_collections.layers.pydantic_ai import PYDANTIC_AI_HISTORY_LAYER_TYPE_ID, PydanticAIHistoryLayer
from clients.agent_backend import AgentBackendRunRequestBuilder, CleanupLayerSpec
def test_cleanup_request_passes_agenton_snapshot_validation():
"""The cleanup request's composition layer names must match the (filtered)
snapshot's layer names exactly — agenton's compositor enforces this and
the agent backend rejects mismatches as ``run_failed`` asynchronously,
which is the trap A/D fixed."""
# Persisted (non-plugin) layer specs — these are what cleanup will replay.
# We exclude the dify.execution_context layer from this integration check
# because its real provider needs a plugin-daemon HTTP client; the cleanup
# validation we are exercising is the snapshot-vs-composition name check,
# which is purely structural and does not depend on which non-plugin layer
# types appear.
persisted_specs = [
CleanupLayerSpec(
name="workflow_node_job_prompt",
type=PLAIN_PROMPT_LAYER_TYPE_ID,
config={"prefix": "Do the cleanup."},
),
CleanupLayerSpec(name="history", type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID),
]
# Saved snapshot still carries the LLM layer entry — cleanup's
# ``_filter_snapshot_to_specs`` must drop it so names match.
full_snapshot = CompositorSessionSnapshot(
layers=[
LayerSessionSnapshot(
name="workflow_node_job_prompt",
lifecycle_state=LifecycleState.SUSPENDED,
runtime_state={},
),
LayerSessionSnapshot(
name="history",
lifecycle_state=LifecycleState.SUSPENDED,
runtime_state={"messages": []},
),
LayerSessionSnapshot(
name="llm",
lifecycle_state=LifecycleState.SUSPENDED,
runtime_state={},
),
]
)
cleanup_request = AgentBackendRunRequestBuilder().build_cleanup_request(
session_snapshot=full_snapshot,
composition_layer_specs=persisted_specs,
)
# Drive the real agenton compositor through ``from_config`` + ``_create_run``
# the same way the agent backend's RunScheduler does. ``_create_run`` is the
# private path that calls ``_validate_session_snapshot``; we use it directly
# to keep the test synchronous (no async ``enter()`` lifecycle needed —
# validation is the only thing under test).
config = {
"schema_version": 1,
"layers": [
{"name": layer.name, "type": layer.type, "deps": dict(layer.deps), "metadata": dict(layer.metadata)}
for layer in cleanup_request.composition.layers
],
}
compositor = Compositor.from_config(
config,
providers=[
LayerProvider.from_layer_type(PromptLayer),
LayerProvider.from_layer_type(PydanticAIHistoryLayer),
],
)
layer_configs = {layer.name: layer.config for layer in cleanup_request.composition.layers}
# This is the call that would raise ``ValueError`` if the cleanup snapshot
# and composition disagreed on layer names — the exact failure mode the
# original ``layers=[]`` cleanup hit.
run = compositor._create_run( # type: ignore[reportPrivateUsage]
configs=cast(dict[str, object], layer_configs),
session_snapshot=cleanup_request.session_snapshot,
)
assert list(run.slots.keys()) == ["workflow_node_job_prompt", "history"]
def test_cleanup_request_with_mismatched_specs_would_be_rejected_by_agenton():
"""Regression sentinel: if a future refactor stops filtering the snapshot,
agenton would reject the request — and that rejection is what the runtime
fix is preventing. We confirm the validator does fail when given the
pre-fix shape so the previous test's success is not a coincidence."""
snapshot_with_extra = CompositorSessionSnapshot(
layers=[
LayerSessionSnapshot(
name="history",
lifecycle_state=LifecycleState.SUSPENDED,
runtime_state={},
),
LayerSessionSnapshot(
name="llm", # extra layer not in composition
lifecycle_state=LifecycleState.SUSPENDED,
runtime_state={},
),
]
)
compositor = Compositor.from_config(
{
"schema_version": 1,
"layers": [{"name": "history", "type": PYDANTIC_AI_HISTORY_LAYER_TYPE_ID, "deps": {}, "metadata": {}}],
},
providers=[LayerProvider.from_layer_type(PydanticAIHistoryLayer)],
)
with pytest.raises(ValueError, match="layer names must match"):
compositor._create_run( # type: ignore[reportPrivateUsage]
configs={},
session_snapshot=snapshot_with_extra,
)

View File

@ -63,25 +63,3 @@ def test_fake_client_cancel_run_returns_cancelled_status():
assert cancelled.run_id == "fake-run-1"
assert cancelled.status == "cancelled"
def test_fake_client_paused_scenario_returns_paused_status_and_event():
"""The paused scenario exists for HITL-style flows; both ``wait_run`` and
the event stream must report the pause so consumers can branch on it."""
client = FakeAgentBackendRunClient(scenario=FakeAgentBackendScenario.PAUSED)
status = client.wait_run("fake-run-1")
events = list(client.stream_events("fake-run-1"))
assert status.status == "paused"
assert status.error is None
assert events[-1].type == "run_paused"
assert events[-1].data.reason == "human_input_required"
def test_fake_client_success_wait_run_returns_succeeded_status():
"""Covers the default SUCCESS branch of ``wait_run`` directly."""
status = FakeAgentBackendRunClient().wait_run("fake-run-1")
assert status.status == "succeeded"
assert status.error is None

View File

@ -1,23 +1,15 @@
from typing import Any, cast
import pytest
from agenton.compositor import CompositorSessionSnapshot
from agenton.compositor.schemas import LayerSessionSnapshot
from agenton.layers import ExitIntent
from agenton.layers.base import LifecycleState
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig
from agenton_collections.layers.pydantic_ai import PYDANTIC_AI_HISTORY_LAYER_TYPE_ID
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID
from dify_agent.layers.dify_plugin import (
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
DifyPluginLLMLayerConfig,
DifyPluginToolConfig,
DifyPluginToolsLayerConfig,
)
from dify_agent.layers.execution_context import DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, DifyExecutionContextLayerConfig
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID
from dify_agent.protocol import (
DIFY_AGENT_HISTORY_LAYER_ID,
DIFY_AGENT_MODEL_LAYER_ID,
DIFY_AGENT_OUTPUT_LAYER_ID,
CreateRunRequest,
@ -34,7 +26,6 @@ from clients.agent_backend import (
AgentBackendOutputConfig,
AgentBackendRunRequestBuilder,
AgentBackendWorkflowNodeRunInput,
CleanupLayerSpec,
redact_for_agent_backend_log,
)
@ -80,11 +71,10 @@ def test_request_builder_outputs_dify_agent_create_run_request():
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
WORKFLOW_USER_PROMPT_LAYER_ID,
DIFY_EXECUTION_CONTEXT_LAYER_ID,
DIFY_AGENT_HISTORY_LAYER_ID,
DIFY_AGENT_MODEL_LAYER_ID,
DIFY_AGENT_OUTPUT_LAYER_ID,
]
assert request.on_exit.default is ExitIntent.SUSPEND
assert request.on_exit.default is ExitIntent.DELETE
assert request.idempotency_key == "workflow-run-1:node-execution-1"
assert request.metadata == {"workflow_id": "workflow-1", "node_id": "node-1"}
@ -109,10 +99,9 @@ def test_request_builder_sets_model_and_output_layer_contract_ids():
layers = {layer.name: layer for layer in request.composition.layers}
assert layers[DIFY_EXECUTION_CONTEXT_LAYER_ID].type == DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID
assert cast(DifyExecutionContextLayerConfig, layers[DIFY_EXECUTION_CONTEXT_LAYER_ID].config).user_id == "user-1"
assert layers[DIFY_AGENT_HISTORY_LAYER_ID].type == PYDANTIC_AI_HISTORY_LAYER_TYPE_ID
assert layers[DIFY_EXECUTION_CONTEXT_LAYER_ID].config.user_id == "user-1"
assert layers[DIFY_AGENT_MODEL_LAYER_ID].type == DIFY_PLUGIN_LLM_LAYER_TYPE_ID
assert cast(DifyPluginLLMLayerConfig, layers[DIFY_AGENT_MODEL_LAYER_ID].config).plugin_id == "langgenius/openai"
assert layers[DIFY_AGENT_MODEL_LAYER_ID].config.plugin_id == "langgenius/openai"
assert layers[DIFY_AGENT_MODEL_LAYER_ID].deps == {"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID}
assert layers[DIFY_AGENT_OUTPUT_LAYER_ID].type == DIFY_OUTPUT_LAYER_TYPE_ID
@ -141,92 +130,16 @@ def test_request_builder_adds_dify_plugin_tools_layer_when_configured():
assert layers[DIFY_PLUGIN_TOOLS_LAYER_ID].type == DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID
assert layers[DIFY_PLUGIN_TOOLS_LAYER_ID].deps == {"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID}
tools_config = cast(DifyPluginToolsLayerConfig, layers[DIFY_PLUGIN_TOOLS_LAYER_ID].config)
assert tools_config.tools[0].tool_name == "current_time"
assert layers[DIFY_PLUGIN_TOOLS_LAYER_ID].config.tools[0].tool_name == "current_time"
def test_request_builder_can_delete_on_exit_for_cleanup_paths():
def test_request_builder_can_suspend_on_exit_for_resume_or_babysit_paths():
run_input = _run_input()
run_input.suspend_on_exit = False
run_input.suspend_on_exit = True
request = AgentBackendRunRequestBuilder().build_for_workflow_node(run_input)
assert request.on_exit.default is ExitIntent.DELETE
def test_request_builder_builds_cleanup_request_replays_persisted_layer_specs():
"""The cleanup request must replay the persisted (non-plugin) layer specs
and filter the snapshot to match so the agenton compositor's
snapshot-vs-composition name-order validator passes."""
session_snapshot = CompositorSessionSnapshot(
layers=[
LayerSessionSnapshot(name="history", lifecycle_state=LifecycleState.SUSPENDED, runtime_state={"k": 1}),
LayerSessionSnapshot(name="llm", lifecycle_state=LifecycleState.SUSPENDED, runtime_state={}),
]
)
specs = [CleanupLayerSpec(name="history", type="pydantic_ai.history")]
request = AgentBackendRunRequestBuilder().build_cleanup_request(
session_snapshot=session_snapshot,
composition_layer_specs=specs,
idempotency_key="run-1:node-1:binding-1:agent-session-cleanup",
metadata={"workflow_run_id": "run-1"},
)
assert [layer.name for layer in request.composition.layers] == ["history"]
assert request.session_snapshot is not None
assert [layer.name for layer in request.session_snapshot.layers] == ["history"]
assert request.on_exit.default is ExitIntent.DELETE
assert request.idempotency_key == "run-1:node-1:binding-1:agent-session-cleanup"
assert request.metadata["agent_backend_lifecycle"] == "session_cleanup"
def test_request_builder_rejects_empty_composition_layer_specs():
"""Empty specs would put us back in the original ``layers=[]`` trap that
fails on agenton's snapshot-vs-composition validation."""
with pytest.raises(ValueError, match="composition_layer_specs"):
AgentBackendRunRequestBuilder().build_cleanup_request(
session_snapshot=CompositorSessionSnapshot(layers=[]),
composition_layer_specs=[],
)
def test_extract_cleanup_layer_specs_drops_plugin_layers_keeps_configs():
from dify_agent.protocol import RunComposition, RunLayerSpec
from clients.agent_backend import extract_cleanup_layer_specs
composition = RunComposition(
layers=[
RunLayerSpec(
name="agent_soul_prompt",
type="plain.prompt",
config=PromptLayerConfig(prefix="hello"),
),
RunLayerSpec(
name="llm",
type="dify.plugin.llm",
config=None, # protocol allows None; the redacted config is what matters
),
RunLayerSpec(
name="tools",
type="dify.plugin.tools",
),
RunLayerSpec(
name="history",
type="pydantic_ai.history",
),
]
)
specs = extract_cleanup_layer_specs(composition)
assert [spec.name for spec in specs] == ["agent_soul_prompt", "history"]
# Non-plugin configs are dumped as JSON-compatible dicts so the persisted
# row can be replayed without holding live pydantic instances.
soul_config = specs[0].config
assert isinstance(soul_config, dict)
assert soul_config.get("prefix") == "hello"
assert request.on_exit.default is ExitIntent.SUSPEND
def test_request_builder_rejects_blank_prompts():
@ -246,6 +159,6 @@ def test_request_builder_rejects_blank_prompts():
def test_redact_for_agent_backend_log_hides_credentials():
request = AgentBackendRunRequestBuilder().build_for_workflow_node(_run_input())
redacted = cast(dict[str, Any], redact_for_agent_backend_log(request))
redacted = redact_for_agent_backend_log(request)
assert redacted["composition"]["layers"][5]["config"]["credentials"] == "[REDACTED]"
assert redacted["composition"]["layers"][4]["config"]["credentials"] == "[REDACTED]"

View File

@ -34,6 +34,7 @@ def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.RuleGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
monkeypatch.setattr(generator_module.LLMGenerator, "generate_rule_config", lambda **_kwargs: {"rules": []})
with app.test_request_context(
@ -41,7 +42,7 @@ def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
method="POST",
json={"instruction": "do it", "model_config": _model_config_payload()},
):
response = method("t1")
response = method()
assert response == {"rules": []}
@ -50,6 +51,8 @@ def test_rule_code_generate_maps_token_error(app, monkeypatch: pytest.MonkeyPatc
api = generator_module.RuleCodeGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
def _raise(*_args, **_kwargs):
raise ProviderTokenNotInitError("missing token")
@ -61,13 +64,15 @@ def test_rule_code_generate_maps_token_error(app, monkeypatch: pytest.MonkeyPatc
json={"instruction": "do it", "model_config": _model_config_payload()},
):
with pytest.raises(ProviderNotInitializeError):
method("t1")
method()
def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: None))
with app.test_request_context(
@ -80,7 +85,7 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch
"model_config": _model_config_payload(),
},
):
response, status = method("t1")
response, status = method()
assert status == 400
assert response["error"] == "app app-1 not found"
@ -90,6 +95,8 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
app_model = SimpleNamespace(id="app-1")
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
_install_workflow_service(monkeypatch, workflow=None)
@ -104,7 +111,7 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey
"model_config": _model_config_payload(),
},
):
response, status = method("t1")
response, status = method()
assert status == 400
assert response["error"] == "workflow app-1 not found"
@ -114,6 +121,8 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch)
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
app_model = SimpleNamespace(id="app-1")
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
@ -130,7 +139,7 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch)
"model_config": _model_config_payload(),
},
):
response, status = method("t1")
response, status = method()
assert status == 400
assert response["error"] == "node node-1 not found"
@ -140,6 +149,8 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) ->
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
app_model = SimpleNamespace(id="app-1")
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
@ -163,7 +174,7 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) ->
"model_config": _model_config_payload(),
},
):
response = method("t1")
response = method()
assert response == {"code": "x"}
@ -172,6 +183,7 @@ def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
monkeypatch.setattr(
generator_module.LLMGenerator,
"instruction_modify_legacy",
@ -189,7 +201,7 @@ def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch
"model_config": _model_config_payload(),
},
):
response = method("t1")
response = method()
assert response == {"instruction": "ok"}
@ -198,6 +210,8 @@ def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.Monke
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
with app.test_request_context(
"/console/api/instruction-generate",
method="POST",
@ -209,7 +223,7 @@ def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.Monke
"model_config": _model_config_payload(),
},
):
response, status = method("t1")
response, status = method()
assert status == 400
assert response["error"] == "incompatible parameters"

View File

@ -121,6 +121,7 @@ class TestAppMCPServerController:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch("controllers.console.app.mcp_server.current_account_with_tenant", return_value=(None, "tenant-1")),
patch("controllers.console.app.mcp_server.db.session.add"),
patch("controllers.console.app.mcp_server.db.session.commit"),
patch("controllers.console.app.mcp_server.AppMCPServer.generate_server_code", return_value="server-code"),
@ -130,7 +131,7 @@ class TestAppMCPServerController:
),
):
response, status_code = method(
api, "tenant-1", app_model=SimpleNamespace(id="app-1", name="Demo App", description="App description")
api, app_model=SimpleNamespace(id="app-1", name="Demo App", description="App description")
)
assert response == {"id": "server-1"}

View File

@ -3,18 +3,22 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import HTTPException
import services
from controllers.console.auth.error import (
CannotTransferOwnerToSelfError,
EmailCodeError,
InvalidEmailError,
InvalidTokenError,
MemberNotInTenantError,
NotOwnerError,
OwnerTransferLimitError,
)
from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded
from controllers.console.workspace.members import (
DatasetOperatorMemberListApi,
MemberCancelInviteApi,
MemberInviteEmailApi,
MemberListApi,
MemberUpdateRoleApi,
@ -247,7 +251,135 @@ class TestMemberInviteEmailApi:
assert result["invitation_results"][0]["status"] == "failed"
class TestMemberCancelInviteApi:
def test_cancel_success(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.get") as get_mock,
patch("controllers.console.workspace.members.TenantService.remove_member_from_tenant"),
):
get_mock.return_value = member
result, status = method(api, member.id)
assert status == 200
assert result["result"] == "success"
def test_cancel_not_found(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.get") as get_mock,
):
get_mock.return_value = None
with pytest.raises(HTTPException):
method(api, "x")
def test_cancel_cannot_operate_self(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.get") as get_mock,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.CannotOperateSelfError("x"),
),
):
get_mock.return_value = member
result, status = method(api, member.id)
assert status == 400
def test_cancel_no_permission(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.get") as get_mock,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.NoPermissionError("x"),
),
):
get_mock.return_value = member
result, status = method(api, member.id)
assert status == 403
def test_cancel_member_not_in_tenant(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.get") as get_mock,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.MemberNotInTenantError(),
),
):
get_mock.return_value = member
result, status = method(api, member.id)
assert status == 404
class TestMemberUpdateRoleApi:
def test_update_success(self, app: Flask):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
tenant = MagicMock()
user = MagicMock(current_tenant=tenant)
member = MagicMock()
payload = {"role": "normal"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.get", return_value=member),
patch("controllers.console.workspace.members.TenantService.update_member_role"),
):
result = method(api, "id")
if isinstance(result, tuple):
result = result[0]
assert result["result"] == "success"
def test_update_invalid_role(self, app: Flask):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
@ -259,6 +391,23 @@ class TestMemberUpdateRoleApi:
assert status == 400
def test_update_member_not_found(self, app: Flask):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
payload = {"role": "normal"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.members.current_account_with_tenant",
return_value=(MagicMock(current_tenant=MagicMock()), "t1"),
),
patch("controllers.console.workspace.members.db.session.get", return_value=None),
):
with pytest.raises(HTTPException):
method(api, "id")
class TestDatasetOperatorMemberListApi:
def test_get_success(self, app: Flask):
@ -488,3 +637,27 @@ class TestOwnerTransferApi:
):
with pytest.raises(InvalidTokenError):
method(api, "2")
def test_member_not_in_tenant(self, app: Flask):
api = OwnerTransfer()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
member = MagicMock()
payload = {"token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
return_value={"email": "a@test.com"},
),
patch("controllers.console.workspace.members.db.session.get", return_value=member),
patch("controllers.console.workspace.members.TenantService.is_member", return_value=False),
):
with pytest.raises(MemberNotInTenantError):
method(api, "2")

View File

@ -1,4 +1,4 @@
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
@ -34,11 +34,15 @@ class TestDefaultModelApi:
"/",
query_string={"model_type": ModelType.LLM},
),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
):
service_mock.return_value.get_default_model_of_model_type.return_value = {"model": "gpt-4"}
result = method(api, "tenant1")
result = method(api)
assert "data" in result
@ -58,9 +62,13 @@ class TestDefaultModelApi:
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result = method(api, "tenant1")
result = method(api)
assert result["result"] == "success"
@ -70,11 +78,12 @@ class TestDefaultModelApi:
with (
app.test_request_context("/", query_string={"model_type": ModelType.LLM}),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService") as service,
):
service.return_value.get_default_model_of_model_type.return_value = None
result = method(api, "t1")
result = method(api)
assert "data" in result
@ -86,11 +95,15 @@ class TestModelProviderModelApi:
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
):
service_mock.return_value.get_models_by_provider.return_value = []
result = method(api, "tenant1", "openai")
result = method(api, "openai")
assert "data" in result
@ -109,10 +122,14 @@ class TestModelProviderModelApi:
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
patch("controllers.console.workspace.models.ModelLoadBalancingService"),
):
result, status = method(api, "tenant1", "openai")
result, status = method(api, "openai")
assert status == 200
@ -127,9 +144,13 @@ class TestModelProviderModelApi:
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result, status = method(api, "tenant1", "openai")
result, status = method(api, "openai")
assert status == 204
@ -139,11 +160,12 @@ class TestModelProviderModelApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService") as service,
):
service.return_value.get_models_by_provider.return_value = []
result = method(api, "t1", "openai")
result = method(api, "openai")
assert "data" in result
@ -161,6 +183,10 @@ class TestModelProviderModelCredentialApi:
"model_type": ModelType.LLM,
},
),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as provider_service,
patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb_service,
):
@ -172,7 +198,7 @@ class TestModelProviderModelCredentialApi:
provider_service.return_value.provider_manager.get_provider_model_available_credentials.return_value = []
lb_service.return_value.get_load_balancing_configs.return_value = (False, [])
result = method(api, "tenant1", "openai")
result = method(api, "openai")
assert "credentials" in result
@ -188,9 +214,13 @@ class TestModelProviderModelCredentialApi:
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result, status = method(api, "tenant1", "openai")
result, status = method(api, "openai")
assert status == 201
@ -200,6 +230,7 @@ class TestModelProviderModelCredentialApi:
with (
app.test_request_context("/", query_string={"model": "gpt", "model_type": ModelType.LLM}),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService") as service,
patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb,
):
@ -207,7 +238,7 @@ class TestModelProviderModelCredentialApi:
service.return_value.provider_manager.get_provider_model_available_credentials.return_value = []
lb.return_value.get_load_balancing_configs.return_value = (False, [])
result = method(api, "t1", "openai")
result = method(api, "openai")
assert result["credentials"] == {}
@ -223,9 +254,10 @@ class TestModelProviderModelCredentialApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result, status = method(api, "t1", "openai")
result, status = method(api, "openai")
assert status == 204
@ -243,9 +275,13 @@ class TestModelProviderModelCredentialSwitchApi:
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result = method(api, "tenant1", "openai")
result = method(api, "openai")
assert result["result"] == "success"
@ -262,9 +298,13 @@ class TestModelEnableDisableApis:
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result = method(api, "tenant1", "openai")
result = method(api, "openai")
assert result["result"] == "success"
@ -279,9 +319,13 @@ class TestModelEnableDisableApis:
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result = method(api, "tenant1", "openai")
result = method(api, "openai")
assert result["result"] == "success"
@ -299,9 +343,13 @@ class TestModelProviderModelValidateApi:
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result = method(api, "tenant1", "openai")
result = method(api, "openai")
assert result["result"] == "success"
@ -318,11 +366,15 @@ class TestModelProviderModelValidateApi:
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
):
service_mock.return_value.validate_model_credentials.side_effect = CredentialsValidateFailedError("invalid")
result = method(api, "tenant1", "openai")
result = method(api, "openai")
assert result["result"] == "error"
@ -334,11 +386,15 @@ class TestParameterAndAvailableModels:
with (
app.test_request_context("/", query_string={"model": "gpt-4"}),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
):
service_mock.return_value.get_model_parameter_rules.return_value = []
result = method(api, "tenant1", "openai")
result = method(api, "openai")
assert "data" in result
@ -348,11 +404,15 @@ class TestParameterAndAvailableModels:
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
):
service_mock.return_value.get_models_by_model_type.return_value = []
result = method(api, "tenant1", ModelType.LLM)
result = method(api, ModelType.LLM)
assert "data" in result
@ -362,11 +422,12 @@ class TestParameterAndAvailableModels:
with (
app.test_request_context("/", query_string={"model": "gpt"}),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService") as service,
):
service.return_value.get_model_parameter_rules.return_value = []
result = method(api, "t1", "openai")
result = method(api, "openai")
assert result["data"] == []
@ -376,10 +437,11 @@ class TestParameterAndAvailableModels:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService") as service,
):
service.return_value.get_models_by_model_type.return_value = []
result = method(api, "t1", ModelType.LLM)
result = method(api, ModelType.LLM)
assert result["data"] == []

View File

@ -1,73 +1,66 @@
from controllers.openapi.auth.composition import account_pipeline, auth_router, external_sso_pipeline
from controllers.openapi.auth.flow import When
from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter
from libs.oauth_bearer import TokenType
from unittest.mock import patch
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE, _resolve_app_authz_strategy
from controllers.openapi.auth.pipeline import Pipeline
from controllers.openapi.auth.steps import (
AppAuthzCheck,
AppResolver,
BearerCheck,
CallerMount,
ScopeCheck,
SurfaceCheck,
WorkspaceMembershipCheck,
)
from controllers.openapi.auth.strategies import (
AccountMounter,
AclStrategy,
EndUserMounter,
MembershipStrategy,
)
from libs.oauth_bearer import SubjectType
def test_account_pipeline_is_auth_pipeline():
assert isinstance(account_pipeline, AuthPipeline)
def test_pipeline_is_composed():
assert isinstance(OAUTH_BEARER_PIPELINE, Pipeline)
def test_external_sso_pipeline_is_auth_pipeline():
assert isinstance(external_sso_pipeline, AuthPipeline)
def test_pipeline_step_order():
"""BearerCheck → SurfaceCheck → ScopeCheck → AppResolver →
WorkspaceMembershipCheck → AppAuthzCheck → CallerMount.
SurfaceCheck enforces the dfoa_/dfoe_ surface split + emits
`openapi.wrong_surface_denied`. Rate-limit is enforced inside
`BearerAuthenticator.authenticate`, not as a separate pipeline step."""
steps = OAUTH_BEARER_PIPELINE._steps
assert isinstance(steps[0], BearerCheck)
assert isinstance(steps[1], SurfaceCheck)
assert isinstance(steps[2], ScopeCheck)
assert isinstance(steps[3], AppResolver)
assert isinstance(steps[4], WorkspaceMembershipCheck)
assert isinstance(steps[5], AppAuthzCheck)
assert isinstance(steps[6], CallerMount)
def test_auth_router_is_pipeline_router():
assert isinstance(auth_router, PipelineRouter)
def test_pipeline_surface_check_accepts_account_only():
"""Current pipeline serves /apps/<id>/run — account surface only."""
surface = OAUTH_BEARER_PIPELINE._steps[1]
assert isinstance(surface, SurfaceCheck)
assert surface._accepted == frozenset({SubjectType.ACCOUNT})
def test_account_pipeline_prepare_has_four_entries():
assert len(account_pipeline._prepare) == 4
def test_caller_mount_has_both_mounters():
cm = OAUTH_BEARER_PIPELINE._steps[6]
kinds = {type(m) for m in cm._mounters}
assert AccountMounter in kinds
assert EndUserMounter in kinds
def test_account_auth_list_has_five_entries():
assert len(account_pipeline._auth) == 5
@patch("controllers.openapi.auth.composition.FeatureService")
def test_strategy_resolver_picks_acl_when_enabled(fs):
fs.get_system_features.return_value.webapp_auth.enabled = True
assert isinstance(_resolve_app_authz_strategy(), AclStrategy)
def test_external_sso_pipeline_prepare_has_four_entries():
assert len(external_sso_pipeline._prepare) == 4
def test_external_sso_auth_list_has_three_entries():
assert len(external_sso_pipeline._auth) == 3
def test_account_pipeline_has_unconditional_load_account():
non_when = [s for s in account_pipeline._prepare if not isinstance(s, When)]
assert len(non_when) == 1
def test_external_sso_pipeline_all_prepare_entries_are_when():
assert all(isinstance(s, When) for s in external_sso_pipeline._prepare)
def test_first_auth_entry_is_check_scope_in_both_pipelines():
assert not isinstance(account_pipeline._auth[0], When)
assert not isinstance(external_sso_pipeline._auth[0], When)
def test_remaining_auth_entries_are_when_for_account():
assert all(isinstance(s, When) for s in account_pipeline._auth[1:])
def test_remaining_auth_entries_are_when_for_external_sso():
assert all(isinstance(s, When) for s in external_sso_pipeline._auth[1:])
def test_router_routes_contain_both_token_types():
assert TokenType.OAUTH_ACCOUNT in auth_router._routes
assert TokenType.OAUTH_EXTERNAL_SSO in auth_router._routes
def test_external_sso_route_has_ee_required_edition():
route = auth_router._routes[TokenType.OAUTH_EXTERNAL_SSO]
assert isinstance(route, PipelineRoute)
from controllers.openapi.auth.data import Edition
assert route.required_edition == frozenset({Edition.EE})
def test_account_route_has_no_required_edition():
route = auth_router._routes[TokenType.OAUTH_ACCOUNT]
assert isinstance(route, PipelineRoute)
assert route.required_edition is None
@patch("controllers.openapi.auth.composition.FeatureService")
def test_strategy_resolver_picks_membership_when_disabled(fs):
fs.get_system_features.return_value.webapp_auth.enabled = False
assert isinstance(_resolve_app_authz_strategy(), MembershipStrategy)

View File

@ -1,143 +0,0 @@
from unittest.mock import MagicMock, patch
from controllers.openapi.auth.conditions import (
EDITION_CE,
EDITION_EE,
EDITION_SAAS,
LOADED_APP_IS_PRIVATE,
PATH_HAS_APP_ID,
TOKEN_IS_OAUTH_ACCOUNT,
TOKEN_IS_OAUTH_EXTERNAL_SSO,
WEBAPP_AUTH_ENABLED,
Cond,
config_cond,
data_cond,
request_cond,
)
from controllers.openapi.auth.data import AuthData, Edition, RequestContext
from libs.oauth_bearer import TokenType
from services.enterprise.enterprise_service import WebAppAccessMode
def _ctx(token_type=TokenType.OAUTH_ACCOUNT, path_params=None):
return RequestContext(
token_type=token_type,
path_params=path_params or {},
)
def _data(**kwargs):
defaults: dict = {"token_type": TokenType.OAUTH_ACCOUNT, "token_hash": "x", "scopes": frozenset()}
defaults.update(kwargs)
return AuthData(**defaults)
def test_and_both_true():
a = Cond(lambda ctx, _: True)
b = Cond(lambda ctx, _: True)
assert (a & b)(_ctx()) is True
def test_and_one_false():
a = Cond(lambda ctx, _: True)
b = Cond(lambda ctx, _: False)
assert (a & b)(_ctx()) is False
def test_or_one_true():
a = Cond(lambda ctx, _: False)
b = Cond(lambda ctx, _: True)
assert (a | b)(_ctx()) is True
def test_or_both_false():
a = Cond(lambda ctx, _: False)
b = Cond(lambda ctx, _: False)
assert (a | b)(_ctx()) is False
def test_invert():
a = Cond(lambda ctx, _: True)
assert (~a)(_ctx()) is False
def test_chain_and_or():
always_true = Cond(lambda ctx, _: True)
always_false = Cond(lambda ctx, _: False)
assert ((always_true | always_false) & always_true)(_ctx()) is True
def test_request_cond_ignores_data():
c = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_ACCOUNT)
assert c(_ctx(TokenType.OAUTH_ACCOUNT)) is True
assert c(_ctx(TokenType.OAUTH_EXTERNAL_SSO)) is False
def test_data_cond_returns_false_when_data_none():
c = data_cond(lambda data: True)
assert c(_ctx(), None) is False
def test_data_cond_evaluates_when_data_present():
c = data_cond(lambda data: data.token_hash == "secret")
assert c(_ctx(), _data(token_hash="secret")) is True
assert c(_ctx(), _data(token_hash="other")) is False
def test_config_cond_ignores_ctx_and_data():
c = config_cond(lambda: True)
assert c(_ctx()) is True
c2 = config_cond(lambda: False)
assert c2(_ctx(), _data()) is False
def test_token_is_oauth_account():
assert TOKEN_IS_OAUTH_ACCOUNT(_ctx(TokenType.OAUTH_ACCOUNT)) is True
assert TOKEN_IS_OAUTH_ACCOUNT(_ctx(TokenType.OAUTH_EXTERNAL_SSO)) is False
def test_token_is_oauth_external_sso():
assert TOKEN_IS_OAUTH_EXTERNAL_SSO(_ctx(TokenType.OAUTH_EXTERNAL_SSO)) is True
def test_path_has_app_id_true():
assert PATH_HAS_APP_ID(_ctx(path_params={"app_id": "abc"})) is True
def test_path_has_app_id_false():
assert PATH_HAS_APP_ID(_ctx(path_params={})) is False
def test_edition_ce():
with patch("controllers.openapi.auth.conditions.current_edition", return_value=Edition.CE):
assert EDITION_CE(_ctx()) is True
assert EDITION_EE(_ctx()) is False
assert EDITION_SAAS(_ctx()) is False
def test_edition_ee():
with patch("controllers.openapi.auth.conditions.current_edition", return_value=Edition.EE):
assert EDITION_EE(_ctx()) is True
assert EDITION_CE(_ctx()) is False
def test_edition_saas():
with patch("controllers.openapi.auth.conditions.current_edition", return_value=Edition.SAAS):
assert EDITION_SAAS(_ctx()) is True
def test_webapp_auth_enabled():
mock_features = MagicMock()
mock_features.webapp_auth.enabled = True
with patch("controllers.openapi.auth.conditions.FeatureService.get_system_features", return_value=mock_features):
assert WEBAPP_AUTH_ENABLED(_ctx()) is True
def test_loaded_app_is_private():
data_private = _data(app_access_mode=WebAppAccessMode.PRIVATE)
data_public = _data(app_access_mode=WebAppAccessMode.PUBLIC)
data_none = _data(app_access_mode=None)
assert LOADED_APP_IS_PRIVATE(_ctx(), data_private) is True
assert LOADED_APP_IS_PRIVATE(_ctx(), data_public) is False
assert LOADED_APP_IS_PRIVATE(_ctx(), data_none) is False
assert LOADED_APP_IS_PRIVATE(_ctx(), None) is False

View File

@ -0,0 +1,21 @@
from controllers.openapi.auth.context import Context
def test_context_starts_unpopulated():
ctx = Context(required_scope="apps:run")
assert ctx.bearer_token is None
assert ctx.path_params == {}
assert ctx.subject_type is None
assert ctx.subject_email is None
assert ctx.account_id is None
assert ctx.scopes == frozenset()
assert ctx.app is None
assert ctx.tenant is None
assert ctx.caller is None
assert ctx.caller_kind is None
def test_context_fields_are_mutable():
ctx = Context(required_scope="apps:run")
ctx.scopes = frozenset({"full"})
assert "full" in ctx.scopes

View File

@ -1,117 +0,0 @@
import uuid
from unittest.mock import patch
import pytest
from pydantic import ValidationError
from controllers.openapi.auth.data import (
AuthData,
Edition,
ExternalIdentity,
RequestContext,
current_edition,
)
from libs.oauth_bearer import Scope, TokenType
def test_current_edition_saas():
with patch("controllers.openapi.auth.data.dify_config") as cfg:
cfg.EDITION = "CLOUD"
cfg.ENTERPRISE_ENABLED = True
assert current_edition() == Edition.SAAS
def test_current_edition_ee():
with patch("controllers.openapi.auth.data.dify_config") as cfg:
cfg.EDITION = "SELF_HOSTED"
cfg.ENTERPRISE_ENABLED = True
assert current_edition() == Edition.EE
def test_current_edition_ce():
with patch("controllers.openapi.auth.data.dify_config") as cfg:
cfg.EDITION = "SELF_HOSTED"
cfg.ENTERPRISE_ENABLED = False
assert current_edition() == Edition.CE
def test_external_identity_frozen():
ei = ExternalIdentity(email="a@b.com", issuer="idp")
with pytest.raises(ValidationError):
ei.email = "other@b.com" # type: ignore[misc]
def test_external_identity_issuer_optional():
ei = ExternalIdentity(email="a@b.com")
assert ei.issuer is None
def test_request_context_frozen():
ctx = RequestContext(
token_type=TokenType.OAUTH_ACCOUNT,
path_params={"app_id": "123"},
)
with pytest.raises(ValidationError):
ctx.token_type = TokenType.OAUTH_EXTERNAL_SSO # type: ignore[misc]
def test_request_context_scope_optional():
ctx = RequestContext(token_type=TokenType.OAUTH_ACCOUNT, path_params={})
assert ctx.scope is None
def test_auth_data_is_mutable():
data = AuthData(
token_type=TokenType.OAUTH_ACCOUNT,
token_hash="abc",
scopes=frozenset({Scope.FULL}),
)
data.token_type = TokenType.OAUTH_EXTERNAL_SSO
assert data.token_type == TokenType.OAUTH_EXTERNAL_SSO
def test_auth_data_path_params_defaults_empty():
data = AuthData(
token_type=TokenType.OAUTH_ACCOUNT,
token_hash="abc",
scopes=frozenset(),
)
assert data.path_params == {}
def test_auth_data_account_id_optional():
data = AuthData(
token_type=TokenType.OAUTH_EXTERNAL_SSO,
token_hash="abc",
scopes=frozenset({Scope.APPS_RUN}),
external_identity=ExternalIdentity(email="u@sso.com"),
)
assert data.account_id is None
def test_auth_data_external_identity_none_for_account():
data = AuthData(
token_type=TokenType.OAUTH_ACCOUNT,
account_id=uuid.uuid4(),
token_hash="abc",
scopes=frozenset({Scope.FULL}),
)
assert data.external_identity is None
def test_auth_data_tenants_default_empty():
data = AuthData(
token_type=TokenType.OAUTH_ACCOUNT,
token_hash="abc",
scopes=frozenset(),
)
assert data.tenants == {}
def test_auth_data_token_id_optional():
data = AuthData(
token_type=TokenType.OAUTH_ACCOUNT,
token_hash="abc",
scopes=frozenset(),
)
assert data.token_id is None

View File

@ -1,42 +0,0 @@
import inspect
from controllers.openapi.auth.conditions import Cond
from controllers.openapi.auth.data import AuthData, RequestContext
from controllers.openapi.auth.flow import When
from libs.oauth_bearer import TokenType
def _ctx():
return RequestContext(token_type=TokenType.OAUTH_ACCOUNT, path_params={})
def _data():
return AuthData(token_type=TokenType.OAUTH_ACCOUNT, token_hash="x", scopes=frozenset())
def test_applies_returns_true_when_condition_true():
w = When(Cond(lambda ctx, _: True), then=lambda b: None)
assert w.applies(_ctx()) is True
def test_applies_returns_false_when_condition_false():
w = When(Cond(lambda ctx, _: False), then=lambda b: None)
assert w.applies(_ctx()) is False
def test_applies_with_data():
w = When(Cond(lambda ctx, data: data is not None), then=lambda b: None)
assert w.applies(_ctx(), _data()) is True
assert w.applies(_ctx(), None) is False
def test_call_invokes_step():
calls = []
w = When(Cond(lambda ctx, _: True), then=lambda arg: calls.append(arg))
w("payload")
assert calls == ["payload"]
def test_then_is_keyword_only():
sig = inspect.signature(When.__init__)
assert sig.parameters["then"].kind.name == "KEYWORD_ONLY"

View File

@ -1,269 +1,59 @@
import uuid
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from controllers.openapi.auth.data import AuthData, Edition
from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter
from libs.oauth_bearer import Scope, TokenType
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.pipeline import Pipeline
def _make_identity(
token_type=TokenType.OAUTH_ACCOUNT,
account_id=None,
scopes=None,
token_hash="testhash",
subject_email=None,
subject_issuer=None,
verified_tenants=None,
token_id=None,
):
identity = MagicMock()
identity.token_type = token_type
identity.account_id = account_id or uuid.uuid4()
identity.scopes = scopes or frozenset({Scope.FULL})
identity.token_hash = token_hash
identity.subject_email = subject_email
identity.subject_issuer = subject_issuer
identity.verified_tenants = verified_tenants or {}
identity.token_id = token_id or uuid.uuid4()
return identity
def test_run_invokes_each_step_in_order():
calls = []
class S:
def __init__(self, tag):
self.tag = tag
@pytest.fixture
def app():
return Flask(__name__)
def __call__(self, ctx):
calls.append(self.tag)
Pipeline(S("a"), S("b"), S("c")).run(Context(required_scope="x"))
assert calls == ["a", "b", "c"]
def _make_router(token_type=TokenType.OAUTH_ACCOUNT, prepare=None, auth=None):
pipeline = AuthPipeline(prepare=prepare or [], auth=auth or [])
return PipelineRouter({token_type: PipelineRoute(pipeline)})
def test_run_short_circuits_on_raise():
calls = []
def _fake_identity():
return _make_identity()
class Boom:
def __call__(self, ctx):
raise RuntimeError("boom")
class Tail:
def __call__(self, ctx):
calls.append("ran")
# --- PipelineRouter.guard ---
with pytest.raises(RuntimeError):
Pipeline(Boom(), Tail()).run(Context(required_scope="x"))
assert calls == []
def test_guard_passes_auth_data_to_view(app):
router = _make_router()
received = {}
def test_guard_decorator_runs_pipeline_and_unpacks_handler_kwargs():
seen = {}
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
with (
patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()),
patch("controllers.openapi.auth.pipeline.reset_auth_ctx"),
):
mock_auth.return_value.authenticate.return_value = _fake_identity()
class FakeStep:
def __call__(self, ctx):
ctx.app = "APP"
ctx.caller = "CALLER"
ctx.caller_kind = "account"
@router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def view(*, auth_data):
received["data"] = auth_data
pipeline = Pipeline(FakeStep())
view()
@pipeline.guard(scope="apps:run")
def handler(app_model, caller, caller_kind):
seen["app_model"] = app_model
seen["caller"] = caller
seen["caller_kind"] = caller_kind
return "ok"
assert isinstance(received["data"], AuthData)
def test_guard_edition_gate_returns_404(app):
router = _make_router()
with app.test_request_context("/test"):
with patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE):
@router.guard(scope=Scope.FULL, edition=frozenset({Edition.EE}))
def view(*, auth_data):
pass
with pytest.raises(NotFound):
view()
def test_guard_token_type_gate_returns_403(app):
router = _make_router()
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
with (
patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
patch("controllers.openapi.auth.pipeline.emit_wrong_surface"),
patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE),
):
identity = _fake_identity()
identity.token_type = TokenType.OAUTH_EXTERNAL_SSO
mock_auth.return_value.authenticate.return_value = identity
@router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def view(*, auth_data):
pass
with pytest.raises(Forbidden):
view()
def test_guard_unregistered_token_type_returns_403(app):
router = _make_router(token_type=TokenType.OAUTH_ACCOUNT)
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
with (
patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE),
):
identity = _fake_identity()
identity.token_type = TokenType.OAUTH_EXTERNAL_SSO
mock_auth.return_value.authenticate.return_value = identity
@router.guard(scope=Scope.FULL)
def view(*, auth_data):
pass
with pytest.raises(Forbidden):
view()
def test_guard_no_bearer_returns_401(app):
router = _make_router()
with app.test_request_context("/test"):
with patch("controllers.openapi.auth.pipeline.extract_bearer", return_value=None):
@router.guard(scope=Scope.FULL)
def view(*, auth_data):
pass
with pytest.raises(Unauthorized):
view()
def test_guard_runs_prepare_steps_in_order(app):
order = []
def p1(b):
order.append("p1")
def p2(b):
order.append("p2")
router = _make_router(prepare=[p1, p2])
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
with (
patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()),
patch("controllers.openapi.auth.pipeline.reset_auth_ctx"),
):
mock_auth.return_value.authenticate.return_value = _fake_identity()
@router.guard(scope=Scope.FULL)
def view(*, auth_data):
pass
view()
assert order == ["p1", "p2"]
def test_guard_resets_auth_ctx_on_exception(app):
router = _make_router()
reset_called = []
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
with (
patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value="tok"),
patch("controllers.openapi.auth.pipeline.reset_auth_ctx", side_effect=lambda t: reset_called.append(t)),
):
mock_auth.return_value.authenticate.return_value = _fake_identity()
@router.guard(scope=Scope.FULL)
def view(*, auth_data):
raise RuntimeError("boom")
with pytest.raises(RuntimeError):
view()
assert reset_called == ["tok"]
def test_router_rejects_token_type_on_wrong_edition(app):
pipeline = AuthPipeline(prepare=[], auth=[])
route = PipelineRoute(pipeline, required_edition=frozenset({Edition.EE}))
router = PipelineRouter({TokenType.OAUTH_EXTERNAL_SSO: route})
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
with (
patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE),
):
identity = _make_identity(token_type=TokenType.OAUTH_EXTERNAL_SSO)
mock_auth.return_value.authenticate.return_value = identity
@router.guard(scope=Scope.APPS_RUN)
def view(*, auth_data):
pass
with pytest.raises(Forbidden):
view()
def test_guard_populates_external_identity_from_subject_email(app):
from controllers.openapi.auth.data import ExternalIdentity
router = _make_router(token_type=TokenType.OAUTH_EXTERNAL_SSO)
received = {}
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
with (
patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()),
patch("controllers.openapi.auth.pipeline.reset_auth_ctx"),
):
identity = _make_identity(
token_type=TokenType.OAUTH_EXTERNAL_SSO,
subject_email="user@sso.com",
subject_issuer="https://idp.example.com",
)
mock_auth.return_value.authenticate.return_value = identity
@router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_EXTERNAL_SSO}))
def view(*, auth_data):
received["data"] = auth_data
view()
assert isinstance(received["data"].external_identity, ExternalIdentity)
assert received["data"].external_identity.email == "user@sso.com"
assert received["data"].external_identity.issuer == "https://idp.example.com"
def test_guard_no_external_identity_when_subject_email_absent(app):
router = _make_router()
received = {}
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
with (
patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()),
patch("controllers.openapi.auth.pipeline.reset_auth_ctx"),
):
mock_auth.return_value.authenticate.return_value = _make_identity(subject_email=None)
@router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def view(*, auth_data):
received["data"] = auth_data
view()
assert received["data"].external_identity is None
app = Flask(__name__)
with app.test_request_context("/x", method="POST"):
assert handler() == "ok"
assert seen == {"app_model": "APP", "caller": "CALLER", "caller_kind": "account"}

View File

@ -1,183 +0,0 @@
import uuid
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from controllers.openapi.auth.data import AuthData, ExternalIdentity
from controllers.openapi.auth.prepare import (
load_account,
load_app,
load_app_access_mode,
load_tenant,
resolve_external_user,
)
from libs.oauth_bearer import TokenType
def _make_auth_data(**kwargs) -> AuthData:
mock_fields = {k: kwargs.pop(k) for k in ("app", "tenant", "caller") if k in kwargs}
data = AuthData(
token_type=kwargs.pop("token_type", TokenType.OAUTH_ACCOUNT),
token_hash=kwargs.pop("token_hash", "testhash"),
scopes=kwargs.pop("scopes", frozenset()),
**kwargs,
)
for k, v in mock_fields.items():
setattr(data, k, v)
return data
def test_load_app_writes_app_to_data():
app = MagicMock()
app.status = "normal"
app.enable_api = True
data = _make_auth_data(path_params={"app_id": "abc"})
with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=app):
load_app(data)
assert data.app is app
def test_load_app_raises_not_found_when_missing():
data = _make_auth_data(path_params={"app_id": "missing"})
with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=None):
with pytest.raises(NotFound):
load_app(data)
def test_load_app_raises_not_found_when_not_normal():
app = MagicMock()
app.status = "archived"
data = _make_auth_data(path_params={"app_id": "abc"})
with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=app):
with pytest.raises(NotFound):
load_app(data)
def test_load_app_raises_forbidden_when_api_disabled():
app = MagicMock()
app.status = "normal"
app.enable_api = False
data = _make_auth_data(path_params={"app_id": "abc"})
with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=app):
with pytest.raises(Forbidden):
load_app(data)
def test_load_tenant_writes_tenant():
app = MagicMock()
app.tenant_id = uuid.uuid4()
tenant = MagicMock()
tenant.status = "normal"
data = _make_auth_data(app=app)
with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=tenant):
load_tenant(data)
assert data.tenant is tenant
def test_load_tenant_raises_forbidden_when_archived():
from models.account import TenantStatus
app = MagicMock()
app.tenant_id = uuid.uuid4()
tenant = MagicMock()
tenant.status = TenantStatus.ARCHIVE
data = _make_auth_data(app=app)
with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=tenant):
with pytest.raises(Forbidden):
load_tenant(data)
def test_load_tenant_raises_forbidden_when_missing():
app = MagicMock()
app.tenant_id = uuid.uuid4()
data = _make_auth_data(app=app)
with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=None):
with pytest.raises(Forbidden):
load_tenant(data)
def test_load_tenant_raises_500_when_app_not_loaded():
from werkzeug.exceptions import InternalServerError
data = _make_auth_data()
with pytest.raises(InternalServerError):
load_tenant(data)
def test_load_account_writes_caller():
account = MagicMock()
account_id = uuid.uuid4()
data = _make_auth_data(account_id=account_id)
with patch("controllers.openapi.auth.prepare.AccountService.get_account_by_id", return_value=account):
load_account(data)
assert data.caller is account
assert data.caller_kind == "account"
def test_load_account_sets_current_tenant_when_tenant_present():
account = MagicMock()
tenant = MagicMock()
data = _make_auth_data(account_id=uuid.uuid4(), tenant=tenant)
with patch("controllers.openapi.auth.prepare.AccountService.get_account_by_id", return_value=account):
load_account(data)
assert account.current_tenant is tenant
def test_load_account_raises_unauthorized_when_not_found():
data = _make_auth_data(account_id=uuid.uuid4())
with patch("controllers.openapi.auth.prepare.AccountService.get_account_by_id", return_value=None):
with pytest.raises(Unauthorized):
load_account(data)
def test_resolve_external_user_writes_caller():
tenant = MagicMock()
app = MagicMock()
end_user = MagicMock()
ext = ExternalIdentity(email="user@sso.com")
data = _make_auth_data(tenant=tenant, app=app, external_identity=ext)
with patch("controllers.openapi.auth.prepare.EndUserService.get_or_create_end_user_by_type", return_value=end_user):
resolve_external_user(data)
assert data.caller is end_user
assert data.caller_kind == "end_user"
def test_resolve_external_user_raises_unauthorized_when_context_missing():
data = _make_auth_data(tenant=None, app=MagicMock(), external_identity=ExternalIdentity(email="u@s.com"))
with pytest.raises(Unauthorized):
resolve_external_user(data)
def test_load_app_access_mode_writes_mode():
from services.enterprise.enterprise_service import WebAppAccessMode
app = MagicMock()
app.id = "app-1"
settings = MagicMock()
settings.access_mode = "public"
data = _make_auth_data(app=app)
with patch(
"controllers.openapi.auth.prepare.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
return_value=settings,
):
load_app_access_mode(data)
assert data.app_access_mode == WebAppAccessMode.PUBLIC
def test_load_app_access_mode_writes_none_when_value_error():
app = MagicMock()
app.id = "app-1"
data = _make_auth_data(app=app)
with patch(
"controllers.openapi.auth.prepare.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
side_effect=ValueError("No data found."),
):
load_app_access_mode(data)
assert data.app_access_mode is None
def test_load_app_access_mode_no_op_when_app_missing():
data = _make_auth_data()
load_app_access_mode(data)
assert data.app_access_mode is None

View File

@ -26,7 +26,7 @@ from flask import Flask
from werkzeug.exceptions import Forbidden, NotFound
from controllers.openapi.auth.role_gate import require_workspace_role
from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, reset_auth_ctx, set_auth_ctx
from libs.oauth_bearer import AuthContext, Scope, SubjectType, reset_auth_ctx, set_auth_ctx
from models.account import TenantAccountRole
# Tokens from `_seed`'s `set_auth_ctx` calls, drained after each test so a
@ -55,7 +55,7 @@ def _account_ctx(account_id: uuid.UUID | None = None) -> AuthContext:
client_id="difyctl",
scopes=frozenset({Scope.FULL}),
token_id=uuid.uuid4(),
token_type=TokenType.OAUTH_ACCOUNT,
source="oauth_account",
expires_at=datetime.now(UTC),
token_hash="h1",
verified_tenants={},
@ -71,7 +71,7 @@ def _sso_ctx() -> AuthContext:
client_id="difyctl",
scopes=frozenset({Scope.APPS_RUN}),
token_id=uuid.uuid4(),
token_type=TokenType.OAUTH_EXTERNAL_SSO,
source="oauth_external_sso",
expires_at=datetime.now(UTC),
token_hash="h2",
verified_tenants={},

View File

@ -0,0 +1,64 @@
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import AppResolver
from models import TenantStatus
def _ctx(path_params: dict[str, str] | None) -> Context:
return Context(required_scope="apps:run", path_params=path_params or {})
def _app(*, status="normal", enable_api=True):
return SimpleNamespace(id="app1", tenant_id="t1", status=status, enable_api=enable_api)
def _tenant(*, status=TenantStatus.NORMAL):
return SimpleNamespace(id="t1", status=status)
def test_resolver_rejects_missing_path_param():
with pytest.raises(BadRequest):
AppResolver()(_ctx({}))
def test_resolver_rejects_empty_path_params():
# `Pipeline.guard` always seeds an empty dict when Flask reports no
# view args, so a missing `app_id` key surfaces here as BadRequest.
with pytest.raises(BadRequest):
AppResolver()(_ctx(None))
@patch("controllers.openapi.auth.steps.db")
def test_resolver_404_when_app_missing(db):
db.session.get.side_effect = [None]
with pytest.raises(NotFound):
AppResolver()(_ctx({"app_id": "x"}))
@patch("controllers.openapi.auth.steps.db")
def test_resolver_403_when_disabled(db):
db.session.get.side_effect = [_app(enable_api=False)]
with pytest.raises(Forbidden) as exc:
AppResolver()(_ctx({"app_id": "x"}))
assert "service_api_disabled" in str(exc.value.description)
@patch("controllers.openapi.auth.steps.db")
def test_resolver_403_when_tenant_archived(db):
db.session.get.side_effect = [_app(), _tenant(status=TenantStatus.ARCHIVE)]
with pytest.raises(Forbidden):
AppResolver()(_ctx({"app_id": "x"}))
@patch("controllers.openapi.auth.steps.db")
def test_resolver_populates_app_and_tenant(db):
db.session.get.side_effect = [_app(), _tenant()]
ctx = _ctx({"app_id": "x"})
AppResolver()(ctx)
assert ctx.app.id == "app1"
assert ctx.tenant.id == "t1"

View File

@ -0,0 +1,76 @@
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from werkzeug.exceptions import Forbidden
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import AppAuthzCheck
from controllers.openapi.auth.strategies import AclStrategy, MembershipStrategy
from libs.oauth_bearer import SubjectType
def _ctx(*, subject_type, account_id="acc1"):
c = Context(required_scope="apps:run")
c.subject_type = subject_type
c.subject_email = "alice@example.com"
c.account_id = account_id
c.app = SimpleNamespace(id="app1")
c.tenant = SimpleNamespace(id="t1")
return c
@patch("controllers.openapi.auth.strategies.EnterpriseService")
def test_acl_strategy_private_calls_inner_api(ent):
ent.WebAppAuth.get_app_access_mode_by_id.return_value = SimpleNamespace(access_mode="private")
ent.WebAppAuth.is_user_allowed_to_access_webapp.return_value = True
assert AclStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True
ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_called_once_with(
user_id="acc1",
app_id="app1",
)
@pytest.mark.parametrize(
("access_mode", "subject_type", "expected"),
[
("public", SubjectType.ACCOUNT, True),
("public", SubjectType.EXTERNAL_SSO, True),
("sso_verified", SubjectType.ACCOUNT, True),
("sso_verified", SubjectType.EXTERNAL_SSO, True),
("private_all", SubjectType.ACCOUNT, True),
("private_all", SubjectType.EXTERNAL_SSO, False),
("private", SubjectType.EXTERNAL_SSO, False),
],
)
@patch("controllers.openapi.auth.strategies.EnterpriseService")
def test_acl_strategy_subject_mode_matrix(ent, access_mode, subject_type, expected):
"""Step 1 matrix: subject vs access-mode compatibility. No inner API call expected."""
ent.WebAppAuth.get_app_access_mode_by_id.return_value = SimpleNamespace(access_mode=access_mode)
account_id = "acc1" if subject_type == SubjectType.ACCOUNT else None
assert AclStrategy().authorize(_ctx(subject_type=subject_type, account_id=account_id)) is expected
ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_not_called()
@patch("controllers.openapi.auth.strategies.TenantService.account_belongs_to_tenant")
@patch("controllers.openapi.auth.strategies.db")
def test_membership_strategy_uses_join_lookup(db_mock, member):
member.return_value = True
assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True
member.assert_called_once_with(db_mock.session, "acc1", "t1")
def test_membership_strategy_rejects_external_sso():
assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.EXTERNAL_SSO, account_id=None)) is False
def test_app_authz_check_raises_when_strategy_denies():
deny = SimpleNamespace(authorize=lambda c: False)
with pytest.raises(Forbidden) as exc:
AppAuthzCheck(lambda: deny)(_ctx(subject_type=SubjectType.ACCOUNT))
assert "subject_no_app_access" in str(exc.value.description)
def test_app_authz_check_passes_when_strategy_allows():
allow = SimpleNamespace(authorize=lambda c: True)
AppAuthzCheck(lambda: allow)(_ctx(subject_type=SubjectType.ACCOUNT))

View File

@ -0,0 +1,83 @@
import uuid
from datetime import UTC, datetime
from unittest.mock import patch
import pytest
from flask import Flask
from werkzeug.exceptions import Unauthorized
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import BearerCheck
from libs.oauth_bearer import (
AuthContext,
InvalidBearerError,
Scope,
SubjectType,
reset_auth_ctx,
try_get_auth_ctx,
)
def _ctx(bearer_token: str | None) -> Context:
return Context(required_scope="apps:run", bearer_token=bearer_token)
def test_bearer_check_rejects_missing_header():
app = Flask(__name__)
with app.test_request_context(), pytest.raises(Unauthorized):
BearerCheck()(_ctx(None))
@patch("controllers.openapi.auth.steps.get_authenticator")
def test_bearer_check_rejects_unknown_prefix(get_auth):
get_auth.return_value.authenticate.side_effect = InvalidBearerError("invalid_bearer")
app = Flask(__name__)
with app.test_request_context(), pytest.raises(Unauthorized):
BearerCheck()(_ctx("xxx_abc"))
@patch("controllers.openapi.auth.steps.get_authenticator")
def test_bearer_check_populates_context_and_publishes_auth_ctx(get_auth):
tok_id = uuid.uuid4()
authn = AuthContext(
subject_type=SubjectType.ACCOUNT,
subject_email="a@x.com",
subject_issuer=None,
account_id=None,
client_id="difyctl",
scopes=frozenset({Scope.FULL}),
token_id=tok_id,
source="oauth-account",
expires_at=datetime.now(UTC),
token_hash="hash-1",
verified_tenants={},
)
get_auth.return_value.authenticate.return_value = authn
app = Flask(__name__)
ctx = _ctx("dfoa_abc")
with app.test_request_context():
BearerCheck()(ctx)
try:
assert ctx.subject_type == SubjectType.ACCOUNT
assert ctx.subject_email == "a@x.com"
assert ctx.scopes == frozenset({Scope.FULL})
assert ctx.source == "oauth-account"
assert ctx.token_id == tok_id
assert ctx.token_hash == "hash-1"
# BearerCheck must also publish the same identity on the
# openapi auth ContextVar so the surface gate + downstream
# handlers don't see two different identity sources between
# the decorator + pipeline paths. The reset token is parked
# on `ctx.auth_ctx_reset_token` for `Pipeline.guard` to
# consume in its `finally`.
published = try_get_auth_ctx()
assert published is authn
assert published.client_id == "difyctl"
assert ctx.auth_ctx_reset_token is not None
finally:
# In production `Pipeline.guard` resets the ContextVar; in
# this isolated step-level test we reset it ourselves so the
# value doesn't leak into the next test on the same worker.
assert ctx.auth_ctx_reset_token is not None
reset_auth_ctx(ctx.auth_ctx_reset_token)

View File

@ -0,0 +1,157 @@
"""Unit tests for WorkspaceMembershipCheck (Layer 0)."""
from __future__ import annotations
import uuid
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import WorkspaceMembershipCheck
from libs.oauth_bearer import SubjectType
def _ctx(*, subject_type, account_id, tenant_id, cached_verified_tenants=None, token_hash=None) -> Context:
c = Context(required_scope="apps:read")
c.subject_type = subject_type
c.account_id = account_id
c.tenant = SimpleNamespace(id=tenant_id) if tenant_id else None
c.cached_verified_tenants = cached_verified_tenants
c.token_hash = token_hash
return c
@pytest.fixture
def step():
return WorkspaceMembershipCheck()
@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_skips_when_enterprise_enabled(mock_db, mock_record, mock_cfg, step):
mock_cfg.ENTERPRISE_ENABLED = True
ctx = _ctx(
subject_type=SubjectType.ACCOUNT,
account_id=str(uuid.uuid4()),
tenant_id=str(uuid.uuid4()),
cached_verified_tenants={},
token_hash="hash-1",
)
step(ctx) # no raise
mock_db.session.execute.assert_not_called()
mock_record.assert_not_called()
@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_skips_for_external_sso(mock_db, mock_record, mock_cfg, step):
mock_cfg.ENTERPRISE_ENABLED = False
ctx = _ctx(
subject_type=SubjectType.EXTERNAL_SSO,
account_id=None,
tenant_id=str(uuid.uuid4()),
cached_verified_tenants={},
token_hash="hash-1",
)
step(ctx) # no raise
mock_db.session.execute.assert_not_called()
mock_record.assert_not_called()
@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_uses_cached_ok(mock_db, mock_record, mock_cfg, step):
mock_cfg.ENTERPRISE_ENABLED = False
ctx = _ctx(
subject_type=SubjectType.ACCOUNT,
account_id="a1",
tenant_id="t1",
cached_verified_tenants={"t1": True},
token_hash="hash-1",
)
step(ctx)
mock_db.session.execute.assert_not_called()
mock_record.assert_not_called()
@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_uses_cached_denied(mock_db, mock_record, mock_cfg, step):
mock_cfg.ENTERPRISE_ENABLED = False
ctx = _ctx(
subject_type=SubjectType.ACCOUNT,
account_id="a1",
tenant_id="t1",
cached_verified_tenants={"t1": False},
token_hash="hash-1",
)
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
step(ctx)
mock_db.session.execute.assert_not_called()
mock_record.assert_not_called()
@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_denies_when_no_membership(mock_db, mock_record, mock_cfg, step):
mock_cfg.ENTERPRISE_ENABLED = False
mock_db.session.execute.return_value.scalar_one_or_none.return_value = None
ctx = _ctx(
subject_type=SubjectType.ACCOUNT,
account_id="a1",
tenant_id="t1",
cached_verified_tenants={},
token_hash="hash-1",
)
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
step(ctx)
mock_record.assert_called_once_with("hash-1", "t1", False)
@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_denies_when_account_inactive(mock_db, mock_record, mock_cfg, step):
mock_cfg.ENTERPRISE_ENABLED = False
mock_db.session.execute.side_effect = [
MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")),
MagicMock(scalar_one_or_none=MagicMock(return_value="banned")),
]
ctx = _ctx(
subject_type=SubjectType.ACCOUNT,
account_id="a1",
tenant_id="t1",
cached_verified_tenants={},
token_hash="hash-1",
)
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
step(ctx)
mock_record.assert_called_once_with("hash-1", "t1", False)
@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_allows_active_member(mock_db, mock_record, mock_cfg, step):
mock_cfg.ENTERPRISE_ENABLED = False
mock_db.session.execute.side_effect = [
MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")),
MagicMock(scalar_one_or_none=MagicMock(return_value="active")),
]
ctx = _ctx(
subject_type=SubjectType.ACCOUNT,
account_id="a1",
tenant_id="t1",
cached_verified_tenants={},
token_hash="hash-1",
)
step(ctx) # no raise
mock_record.assert_called_once_with("hash-1", "t1", True)

View File

@ -0,0 +1,77 @@
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from werkzeug.exceptions import Unauthorized
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import CallerMount
from controllers.openapi.auth.strategies import AccountMounter, EndUserMounter
from core.app.entities.app_invoke_entities import InvokeFrom
from libs.oauth_bearer import SubjectType
def _ctx(*, subject_type, account_id=None, subject_email=None):
c = Context(required_scope="apps:run")
c.subject_type = subject_type
c.account_id = account_id
c.subject_email = subject_email
c.app = SimpleNamespace(id="app1")
c.tenant = SimpleNamespace(id="t1")
return c
@patch("controllers.openapi.auth.strategies._login_as")
@patch("controllers.openapi.auth.strategies.db")
def test_account_mounter(db, login):
account = SimpleNamespace()
db.session.get.return_value = account
ctx = _ctx(subject_type=SubjectType.ACCOUNT, account_id="acc1")
AccountMounter().mount(ctx)
assert ctx.caller is account
assert ctx.caller.current_tenant is ctx.tenant
assert ctx.caller_kind == "account"
login.assert_called_once_with(account)
@patch("controllers.openapi.auth.strategies._login_as")
@patch("controllers.openapi.auth.strategies.EndUserService")
def test_end_user_mounter(svc, login):
eu = SimpleNamespace()
svc.get_or_create_end_user_by_type.return_value = eu
ctx = _ctx(subject_type=SubjectType.EXTERNAL_SSO, subject_email="a@x.com")
EndUserMounter().mount(ctx)
svc.get_or_create_end_user_by_type.assert_called_once_with(
InvokeFrom.OPENAPI,
tenant_id="t1",
app_id="app1",
user_id="a@x.com",
)
assert ctx.caller is eu
assert ctx.caller_kind == "end_user"
def test_caller_mount_dispatches_by_subject_type():
seen = {}
class Fake:
def __init__(self, st, tag):
self._st, self._tag = st, tag
def applies_to(self, st):
return st == self._st
def mount(self, ctx):
seen["who"] = self._tag
cm = CallerMount(
Fake(SubjectType.ACCOUNT, "acct"),
Fake(SubjectType.EXTERNAL_SSO, "sso"),
)
cm(_ctx(subject_type=SubjectType.EXTERNAL_SSO))
assert seen == {"who": "sso"}
def test_caller_mount_raises_when_none_applies():
with pytest.raises(Unauthorized):
CallerMount()(_ctx(subject_type=SubjectType.ACCOUNT))

View File

@ -0,0 +1,25 @@
import pytest
from werkzeug.exceptions import Forbidden
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import ScopeCheck
def _ctx(scopes, required):
c = Context(required_scope=required)
c.scopes = frozenset(scopes)
return c
def test_scope_check_passes_on_full():
ScopeCheck()(_ctx({"full"}, "apps:run"))
def test_scope_check_passes_on_explicit_match():
ScopeCheck()(_ctx({"apps:run"}, "apps:run"))
def test_scope_check_rejects_when_missing():
with pytest.raises(Forbidden) as exc:
ScopeCheck()(_ctx({"apps:read"}, "apps:run"))
assert "insufficient_scope" in str(exc.value.description)

View File

@ -0,0 +1,239 @@
"""Surface gate tests.
The gate has two attachment forms — decorator (`accept_subjects`) and
pipeline step (`SurfaceCheck`) — and both must:
- 403 on mismatched subject type with a canonical-path hint
- emit `openapi.wrong_surface_denied` once with the right payload
- pass-through on match
- raise RuntimeError (not 403) if the auth ContextVar is unset — that's
a wiring bug, not a user-driven failure
Identity is published via `libs.oauth_bearer.set_auth_ctx` / read with
`try_get_auth_ctx`. Tests wrap the publish in a `_publish_auth_ctx`
context manager so the ContextVar resets even when an assertion fails;
that keeps state from leaking into the next test on the same worker.
"""
from __future__ import annotations
import uuid
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import UTC, datetime
from unittest.mock import patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import SurfaceCheck
from controllers.openapi.auth.surface_gate import _coerce_subject_type, accept_subjects, check_surface
from libs.oauth_bearer import AuthContext, Scope, SubjectType, reset_auth_ctx, set_auth_ctx
@contextmanager
def _publish_auth_ctx(ctx: AuthContext) -> Iterator[None]:
token = set_auth_ctx(ctx)
try:
yield
finally:
reset_auth_ctx(token)
def _account_ctx() -> AuthContext:
return AuthContext(
subject_type=SubjectType.ACCOUNT,
subject_email="user@example.com",
subject_issuer="dify:account",
account_id=uuid.uuid4(),
client_id="difyctl",
scopes=frozenset({Scope.FULL}),
token_id=uuid.uuid4(),
source="oauth_account",
expires_at=datetime.now(UTC),
token_hash="h1",
verified_tenants={},
)
def _sso_ctx() -> AuthContext:
return AuthContext(
subject_type=SubjectType.EXTERNAL_SSO,
subject_email="sso@partner.com",
subject_issuer="https://idp.partner.com",
account_id=None,
client_id="difyctl",
scopes=frozenset({Scope.APPS_RUN, Scope.APPS_READ_PERMITTED_EXTERNAL}),
token_id=uuid.uuid4(),
source="oauth_external_sso",
expires_at=datetime.now(UTC),
token_hash="h2",
verified_tenants={},
)
# ---------------------------------------------------------------------------
# check_surface — shared core
# ---------------------------------------------------------------------------
def test_check_surface_passes_when_subject_in_accepted():
app = Flask(__name__)
with app.test_request_context("/openapi/v1/apps"), _publish_auth_ctx(_account_ctx()):
check_surface(frozenset({SubjectType.ACCOUNT})) # no raise
def test_check_surface_rejects_on_wrong_subject_and_emits_audit():
app = Flask(__name__)
with app.test_request_context("/openapi/v1/permitted-external-apps"), _publish_auth_ctx(_account_ctx()):
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
with pytest.raises(Forbidden) as exc:
check_surface(frozenset({SubjectType.EXTERNAL_SSO}))
assert "wrong_surface" in exc.value.description
# canonical-path hint should point at the caller's surface,
# not the surface they were rejected from
assert "/openapi/v1/apps" in exc.value.description
emit.assert_called_once()
kwargs = emit.call_args.kwargs
assert kwargs["subject_type"] == SubjectType.ACCOUNT.value
assert kwargs["attempted_path"] == "/openapi/v1/permitted-external-apps"
assert kwargs["client_id"] == "difyctl"
assert kwargs["token_id"] is not None
def test_check_surface_rejects_sso_on_account_surface():
app = Flask(__name__)
with app.test_request_context("/openapi/v1/apps"), _publish_auth_ctx(_sso_ctx()):
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
with pytest.raises(Forbidden):
check_surface(frozenset({SubjectType.ACCOUNT}))
kwargs = emit.call_args.kwargs
assert kwargs["subject_type"] == SubjectType.EXTERNAL_SSO.value
def test_check_surface_runtime_error_when_auth_ctx_missing():
"""Missing auth ContextVar means the bearer layer didn't run — wiring
bug, not a user-driven failure. Surface as RuntimeError (loud) so a
future refactor doesn't accidentally let a route skip authentication
and return a 403 that looks identical to a legitimate wrong-surface
deny.
"""
app = Flask(__name__)
with app.test_request_context("/openapi/v1/apps"):
with pytest.raises(RuntimeError):
check_surface(frozenset({SubjectType.ACCOUNT}))
# ---------------------------------------------------------------------------
# @accept_subjects — decorator form
# ---------------------------------------------------------------------------
def _make_app() -> Flask:
app = Flask(__name__)
@app.route("/account-only")
@accept_subjects(SubjectType.ACCOUNT)
def _account_only():
return "ok"
@app.route("/external-only")
@accept_subjects(SubjectType.EXTERNAL_SSO)
def _external_only():
return "ok"
return app
def test_accept_subjects_decorator_passes_on_match():
app = _make_app()
with app.test_request_context("/account-only"), _publish_auth_ctx(_account_ctx()):
# Re-route through the decorated function by reaching for view_function
view = app.view_functions["_account_only"]
assert view() == "ok"
def test_accept_subjects_decorator_403_on_miss():
app = _make_app()
with app.test_request_context("/external-only"), _publish_auth_ctx(_account_ctx()):
view = app.view_functions["_external_only"]
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface"):
with pytest.raises(Forbidden):
view()
# ---------------------------------------------------------------------------
# SurfaceCheck — pipeline step form
# ---------------------------------------------------------------------------
def _pipeline_ctx() -> Context:
# SurfaceCheck reads ``request.path`` from Flask's global request — set up
# via ``app.test_request_context`` in the calling tests — not from Context.
return Context(required_scope=Scope.APPS_RUN)
def test_surface_check_passes_on_match():
step = SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT}))
app = Flask(__name__)
with app.test_request_context("/openapi/v1/apps/x/run"), _publish_auth_ctx(_account_ctx()):
step(_pipeline_ctx()) # no raise
def test_surface_check_rejects_on_miss_and_emits_audit():
step = SurfaceCheck(accepted=frozenset({SubjectType.EXTERNAL_SSO}))
app = Flask(__name__)
with app.test_request_context("/openapi/v1/apps/x/run"), _publish_auth_ctx(_account_ctx()):
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
with pytest.raises(Forbidden):
step(_pipeline_ctx())
emit.assert_called_once()
# ---------------------------------------------------------------------------
# _coerce_subject_type — normalises whatever sat on ctx.subject_type
# ---------------------------------------------------------------------------
#
# The gate reads `ctx.subject_type` via `getattr(..., None)`, so the value
# could be a real enum (happy path), a raw string (e.g. rehydrated from a
# dict-shaped context), `None` (attribute missing), or something unexpected
# from a buggy upstream. The coercer must collapse all of that to
# `SubjectType | None` so `check_surface` can do a clean set-membership
# check and emit a clean audit payload.
def test_coerce_subject_type_returns_none_for_none():
assert _coerce_subject_type(None) is None
def test_coerce_subject_type_returns_enum_instance_unchanged():
# Identity matters: we don't want to round-trip through the string
# constructor for an already-valid enum.
assert _coerce_subject_type(SubjectType.ACCOUNT) is SubjectType.ACCOUNT
assert _coerce_subject_type(SubjectType.EXTERNAL_SSO) is SubjectType.EXTERNAL_SSO
@pytest.mark.parametrize(
("raw", "expected"),
[
("account", SubjectType.ACCOUNT),
("external_sso", SubjectType.EXTERNAL_SSO),
],
)
def test_coerce_subject_type_parses_known_strings(raw: str, expected: SubjectType):
assert _coerce_subject_type(raw) is expected
def test_coerce_subject_type_raises_on_unknown_string():
# Unknown strings reach `SubjectType(raw)` which raises ValueError.
# We surface that loudly rather than silently returning None, because
# a string that *looks* like a subject type but isn't is almost
# certainly an upstream bug worth catching.
with pytest.raises(ValueError):
_coerce_subject_type("not_a_subject")
@pytest.mark.parametrize("raw", [123, 1.5, b"account", object(), ["account"], {"account"}])
def test_coerce_subject_type_returns_none_for_non_string_non_enum(raw: object):
assert _coerce_subject_type(raw) is None

View File

@ -1,142 +0,0 @@
import uuid
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden, Unauthorized
from controllers.openapi.auth.data import AuthData
from controllers.openapi.auth.verify import (
check_acl,
check_app_access,
check_membership,
check_private_app_permission,
check_scope,
)
from libs.oauth_bearer import Scope, TokenType
from models.account import Tenant
from models.model import App
from services.enterprise.enterprise_service import WebAppAccessMode
def _data(**kwargs) -> AuthData:
defaults: dict = {"token_type": TokenType.OAUTH_ACCOUNT, "token_hash": "hash", "scopes": frozenset({Scope.FULL})}
defaults.update(kwargs)
return AuthData(**defaults)
def test_check_scope_passes_when_required_is_none():
check_scope(_data(required_scope=None))
def test_check_scope_passes_when_full_in_scopes():
check_scope(_data(required_scope=Scope.APPS_RUN, scopes=frozenset({Scope.FULL})))
def test_check_scope_passes_when_exact_scope_present():
check_scope(_data(required_scope=Scope.APPS_RUN, scopes=frozenset({Scope.APPS_RUN})))
def test_check_scope_raises_forbidden_when_scope_missing():
with pytest.raises(Forbidden, match="insufficient_scope"):
check_scope(_data(required_scope=Scope.APPS_RUN, scopes=frozenset({Scope.APPS_READ})))
def test_check_membership_raises_unauthorized_when_tenant_none():
with pytest.raises(Unauthorized):
check_membership(_data(tenant=None))
def test_check_membership_calls_check_workspace_membership():
tenant = MagicMock(spec=Tenant)
tenant.id = "tenant-1"
data = _data(
account_id=uuid.uuid4(),
token_hash="myhash",
tenants={"tenant-1": True},
tenant=tenant,
)
with patch("controllers.openapi.auth.verify.check_workspace_membership") as mock_cwm:
check_membership(data)
mock_cwm.assert_called_once_with(
account_id=data.account_id,
tenant_id="tenant-1",
token_hash="myhash",
membership_cache=data.tenants,
)
def test_check_app_access_passes_when_tenant_none():
check_app_access(_data(tenant=None))
def test_check_app_access_passes_when_member():
tenant = MagicMock(spec=Tenant)
tenant.id = "t1"
data = _data(account_id=uuid.uuid4(), tenant=tenant)
with patch("controllers.openapi.auth.verify.TenantService.account_belongs_to_tenant", return_value=True):
check_app_access(data)
def test_check_app_access_raises_when_not_member():
tenant = MagicMock(spec=Tenant)
tenant.id = "t1"
data = _data(account_id=uuid.uuid4(), tenant=tenant)
with patch("controllers.openapi.auth.verify.TenantService.account_belongs_to_tenant", return_value=False):
with pytest.raises(Forbidden, match="subject_no_app_access"):
check_app_access(data)
def test_check_acl_raises_when_app_or_mode_missing():
with pytest.raises(Forbidden):
check_acl(_data(app=None, app_access_mode=None))
def test_check_acl_account_allowed_for_public():
app = MagicMock(spec=App)
data = _data(token_type=TokenType.OAUTH_ACCOUNT, app=app, app_access_mode=WebAppAccessMode.PUBLIC)
check_acl(data)
def test_check_acl_external_sso_blocked_for_private():
app = MagicMock(spec=App)
data = _data(
token_type=TokenType.OAUTH_EXTERNAL_SSO,
app=app,
app_access_mode=WebAppAccessMode.PRIVATE,
)
with pytest.raises(Forbidden, match="subject_not_allowed_for_access_mode"):
check_acl(data)
def test_check_acl_external_sso_allowed_for_sso_verified():
app = MagicMock(spec=App)
data = _data(
token_type=TokenType.OAUTH_EXTERNAL_SSO,
app=app,
app_access_mode=WebAppAccessMode.SSO_VERIFIED,
)
check_acl(data)
def test_check_private_app_permission_raises_when_app_none():
with pytest.raises(Forbidden):
check_private_app_permission(_data(app=None))
def test_check_private_app_permission_raises_when_user_not_allowed():
app = MagicMock(spec=App)
app.id = "app-1"
data = _data(account_id=uuid.uuid4(), app=app)
target = "controllers.openapi.auth.verify.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp"
with patch(target, return_value=False):
with pytest.raises(Forbidden, match="user_not_allowed_for_private_app"):
check_private_app_permission(data)
def test_check_private_app_permission_passes_when_allowed():
app = MagicMock(spec=App)
app.id = "app-1"
data = _data(account_id=uuid.uuid4(), app=app)
target = "controllers.openapi.auth.verify.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp"
with patch(target, return_value=True):
check_private_app_permission(data)

View File

@ -1,36 +1,20 @@
import uuid
import pytest
from flask import Flask
from controllers.openapi import bp as openapi_bp
from controllers.openapi.auth.data import AuthData
from controllers.openapi.auth.pipeline import PipelineRouter
from libs.oauth_bearer import Scope, TokenType
def _stub_execute(self, args, kwargs, view, *, scope=None, allowed_token_types=None, edition=None):
"""Bypass all auth logic; inject minimal AuthData and call the view directly."""
kwargs["auth_data"] = AuthData(
token_type=TokenType.OAUTH_ACCOUNT,
account_id=uuid.uuid4(),
token_hash="test",
token_id=uuid.uuid4(),
scopes=frozenset({Scope.FULL}),
required_scope=scope,
)
return view(*args, **kwargs)
from controllers.openapi.auth.pipeline import Pipeline
@pytest.fixture
def bypass_pipeline(monkeypatch):
"""Stub PipelineRouter._execute so endpoints skip real auth at request time.
"""Stub Pipeline.run so endpoint decoration does not invoke real auth.
Module-level @auth_router.guard(...) captures the real router at import
time — patching guard itself does nothing. Patching _execute on the class
is the seam that fires at request time.
Module-level @OAUTH_BEARER_PIPELINE.guard(...) captures the real
pipeline at import time; mocking the module attribute does not undo
that. Patching Pipeline.run on the class is the bypass that actually
works.
"""
monkeypatch.setattr(PipelineRouter, "_execute", _stub_execute)
monkeypatch.setattr(Pipeline, "run", lambda self, ctx: None)
@pytest.fixture

View File

@ -86,7 +86,7 @@ def test_subject_match_for_account_filters_by_account_id():
"""Account subject scopes queries via account_id."""
import uuid as _uuid
from libs.oauth_bearer import AuthContext, SubjectType, TokenType
from libs.oauth_bearer import AuthContext, SubjectType
from services.oauth_device_flow import subject_match_clauses
aid = _uuid.uuid4()
@ -98,7 +98,7 @@ def test_subject_match_for_account_filters_by_account_id():
client_id="difyctl",
scopes=frozenset({"full"}),
token_id=_uuid.uuid4(),
token_type=TokenType.OAUTH_ACCOUNT,
source="oauth_account",
expires_at=None,
token_hash="h1",
verified_tenants={},
@ -116,7 +116,7 @@ def test_subject_match_for_external_sso_filters_by_email_and_issuer():
"""
import uuid as _uuid
from libs.oauth_bearer import AuthContext, SubjectType, TokenType
from libs.oauth_bearer import AuthContext, SubjectType
from services.oauth_device_flow import subject_match_clauses
ctx = AuthContext(
@ -127,7 +127,7 @@ def test_subject_match_for_external_sso_filters_by_email_and_issuer():
client_id="difyctl",
scopes=frozenset({"apps:run"}),
token_id=_uuid.uuid4(),
token_type=TokenType.OAUTH_EXTERNAL_SSO,
source="oauth_external_sso",
expires_at=None,
token_hash="h1",
verified_tenants={},

View File

@ -57,11 +57,7 @@ def test_stop_task_endpoint_registered(openapi_app):
def test_stop_task_calls_queue_manager_and_graph_engine(app, bypass_pipeline, monkeypatch):
import uuid
from controllers.openapi.app_run import AppRunTaskStopApi
from controllers.openapi.auth.data import AuthData
from libs.oauth_bearer import Scope, TokenType
queue_mock = Mock()
graph_mock = Mock()
@ -73,23 +69,15 @@ def test_stop_task_calls_queue_manager_and_graph_engine(app, bypass_pipeline, mo
monkeypatch.setattr(run_module, "GraphEngineManager", graph_mock)
monkeypatch.setattr(run_module, "redis_client", object())
auth_data = AuthData.model_construct(
token_type=TokenType.OAUTH_ACCOUNT,
account_id=uuid.uuid4(),
token_hash="test",
scopes=frozenset({Scope.FULL}),
app=SimpleNamespace(id="app-1", tenant_id="t-1"),
caller=SimpleNamespace(id="acct-1"),
caller_kind="account",
)
api = AppRunTaskStopApi()
with app.test_request_context("/openapi/v1/apps/app-1/tasks/task-1/stop", method="POST"):
result = api.post.__wrapped__(
api,
app_id="app-1",
task_id="task-1",
auth_data=auth_data,
app_model=SimpleNamespace(id="app-1", tenant_id="t-1"),
caller=SimpleNamespace(id="acct-1"),
caller_kind="account",
)
queue_mock.set_stop_flag_no_user_check.assert_called_once_with("task-1")

View File

@ -4,7 +4,6 @@ from __future__ import annotations
import json
import sys
import uuid
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import Mock
@ -12,23 +11,9 @@ from unittest.mock import Mock
import pytest
from werkzeug.exceptions import NotFound
from controllers.openapi.auth.data import AuthData
from libs.oauth_bearer import Scope, TokenType
from models.human_input import RecipientType
def _make_auth_data(app_model, caller, caller_kind):
return AuthData.model_construct(
token_type=TokenType.OAUTH_ACCOUNT,
account_id=uuid.uuid4(),
token_hash="test",
scopes=frozenset({Scope.FULL}),
app=app_model,
caller=caller,
caller_kind=caller_kind,
)
class TestOpenApiHumanInputFormGet:
def test_get_success(self, app, bypass_pipeline, monkeypatch):
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
@ -58,14 +43,15 @@ class TestOpenApiHumanInputFormGet:
api = OpenApiWorkflowHumanInputFormApi()
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
caller = SimpleNamespace(id="acct-1")
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"):
resp = api.get.__wrapped__(
api,
app_id="app-1",
form_token="tok-1",
auth_data=_make_auth_data(app_model, caller, "account"),
app_model=app_model,
caller=SimpleNamespace(id="acct-1"),
caller_kind="account",
)
payload = json.loads(resp.get_data(as_text=True))
@ -85,7 +71,6 @@ class TestOpenApiHumanInputFormGet:
api = OpenApiWorkflowHumanInputFormApi()
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
caller = SimpleNamespace(id="acct-1")
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/bad"):
with pytest.raises(NotFound):
@ -93,7 +78,9 @@ class TestOpenApiHumanInputFormGet:
api,
app_id="app-1",
form_token="bad",
auth_data=_make_auth_data(app_model, caller, "account"),
app_model=app_model,
caller=SimpleNamespace(id="acct-1"),
caller_kind="account",
)
def test_get_form_wrong_app(self, app, bypass_pipeline, monkeypatch):
@ -110,7 +97,6 @@ class TestOpenApiHumanInputFormGet:
api = OpenApiWorkflowHumanInputFormApi()
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
caller = SimpleNamespace(id="acct-1")
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"):
with pytest.raises(NotFound):
@ -118,7 +104,9 @@ class TestOpenApiHumanInputFormGet:
api,
app_id="app-1",
form_token="tok-1",
auth_data=_make_auth_data(app_model, caller, "account"),
app_model=app_model,
caller=SimpleNamespace(id="acct-1"),
caller_kind="account",
)
def test_get_form_wrong_surface(self, app, bypass_pipeline, monkeypatch):
@ -138,7 +126,6 @@ class TestOpenApiHumanInputFormGet:
api = OpenApiWorkflowHumanInputFormApi()
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
caller = SimpleNamespace(id="acct-1")
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"):
with pytest.raises(NotFound):
@ -146,7 +133,9 @@ class TestOpenApiHumanInputFormGet:
api,
app_id="app-1",
form_token="tok-1",
auth_data=_make_auth_data(app_model, caller, "account"),
app_model=app_model,
caller=SimpleNamespace(id="acct-1"),
caller_kind="account",
)
@ -183,7 +172,9 @@ class TestOpenApiHumanInputFormPost:
api,
app_id="app-1",
form_token="tok-1",
auth_data=_make_auth_data(app_model, caller, "account"),
app_model=app_model,
caller=caller,
caller_kind="account",
)
service_mock.submit_form_by_token.assert_called_once_with(
@ -220,7 +211,9 @@ class TestOpenApiHumanInputFormPost:
api,
app_id="app-1",
form_token="tok-1",
auth_data=_make_auth_data(app_model, caller, "end_user"),
app_model=app_model,
caller=caller,
caller_kind="end_user",
)
service_mock.submit_form_by_token.assert_called_once_with(

View File

@ -3,30 +3,15 @@
from __future__ import annotations
import sys
import uuid
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from werkzeug.exceptions import NotFound
from controllers.openapi.auth.data import AuthData
from libs.oauth_bearer import Scope, TokenType
from models.enums import CreatorUserRole
def _make_auth_data(app_model, caller, caller_kind):
return AuthData.model_construct(
token_type=TokenType.OAUTH_ACCOUNT,
account_id=uuid.uuid4(),
token_hash="test",
scopes=frozenset({Scope.FULL}),
app=app_model,
caller=caller,
caller_kind=caller_kind,
)
def _make_workflow_run(
*,
app_id="app-1",
@ -65,7 +50,6 @@ class TestOpenApiWorkflowEventsApi:
from models.model import AppMode
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
caller = SimpleNamespace(id="acct-1")
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
with pytest.raises(NotFound):
@ -73,7 +57,9 @@ class TestOpenApiWorkflowEventsApi:
api,
app_id="app-1",
task_id="wf-run-1",
auth_data=_make_auth_data(app_model, caller, "account"),
app_model=app_model,
caller=SimpleNamespace(id="acct-1"),
caller_kind="account",
)
def test_not_found_when_run_belongs_to_different_app(self, app, bypass_pipeline, monkeypatch):
@ -91,7 +77,6 @@ class TestOpenApiWorkflowEventsApi:
from models.model import AppMode
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
caller = SimpleNamespace(id="acct-1")
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
with pytest.raises(NotFound):
@ -99,7 +84,9 @@ class TestOpenApiWorkflowEventsApi:
api,
app_id="app-1",
task_id="wf-run-1",
auth_data=_make_auth_data(app_model, caller, "account"),
app_model=app_model,
caller=SimpleNamespace(id="acct-1"),
caller_kind="account",
)
def test_account_caller_checks_created_by_account(self, app, bypass_pipeline, monkeypatch):
@ -128,7 +115,6 @@ class TestOpenApiWorkflowEventsApi:
from models.model import AppMode
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
caller = SimpleNamespace(id="acct-1")
api = self._get_api()
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
@ -137,7 +123,9 @@ class TestOpenApiWorkflowEventsApi:
api,
app_id="app-1",
task_id="wf-run-1",
auth_data=_make_auth_data(app_model, caller, "account"),
app_model=app_model,
caller=SimpleNamespace(id="acct-1"),
caller_kind="account",
)
assert resp.mimetype == "text/event-stream"
@ -155,7 +143,6 @@ class TestOpenApiWorkflowEventsApi:
from models.model import AppMode
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
caller = SimpleNamespace(id="acct-1")
api = self._get_api()
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
@ -164,7 +151,9 @@ class TestOpenApiWorkflowEventsApi:
api,
app_id="app-1",
task_id="wf-run-1",
auth_data=_make_auth_data(app_model, caller, "account"),
app_model=app_model,
caller=SimpleNamespace(id="acct-1"),
caller_kind="account",
)
def test_end_user_caller_checks_created_by_end_user(self, app, bypass_pipeline, monkeypatch):
@ -190,7 +179,6 @@ class TestOpenApiWorkflowEventsApi:
from models.model import AppMode
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
caller = SimpleNamespace(id="eu-1")
api = self._get_api()
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
@ -198,7 +186,9 @@ class TestOpenApiWorkflowEventsApi:
api,
app_id="app-1",
task_id="wf-run-1",
auth_data=_make_auth_data(app_model, caller, "end_user"),
app_model=app_model,
caller=SimpleNamespace(id="eu-1"),
caller_kind="end_user",
)
assert resp.mimetype == "text/event-stream"
@ -232,7 +222,6 @@ class TestOpenApiWorkflowEventsApi:
from models.model import AppMode
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
caller = SimpleNamespace(id="acct-1")
api = self._get_api()
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
@ -240,7 +229,9 @@ class TestOpenApiWorkflowEventsApi:
api,
app_id="app-1",
task_id="wf-run-1",
auth_data=_make_auth_data(app_model, caller, "account"),
app_model=app_model,
caller=SimpleNamespace(id="acct-1"),
caller_kind="account",
)
assert resp.mimetype == "text/event-stream"
chunks = list(resp.response)

View File

@ -38,7 +38,7 @@ from controllers.openapi.workspaces import (
WorkspaceMembersApi,
WorkspaceSwitchApi,
)
from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, reset_auth_ctx, set_auth_ctx
from libs.oauth_bearer import AuthContext, Scope, SubjectType, reset_auth_ctx, set_auth_ctx
from models.account import AccountStatus, TenantAccountRole
from services.errors.account import (
AccountAlreadyInTenantError,
@ -97,25 +97,13 @@ def _auth_ctx(account_id: uuid.UUID | None = None) -> AuthContext:
client_id="difyctl",
scopes=frozenset({Scope.FULL}),
token_id=uuid.uuid4(),
token_type=TokenType.OAUTH_ACCOUNT,
source="oauth_account",
expires_at=datetime.now(UTC),
token_hash="h",
verified_tenants={},
)
def _auth_data(account_id: uuid.UUID) -> AuthData:
from controllers.openapi.auth.data import AuthData
from libs.oauth_bearer import Scope, TokenType
return AuthData(
token_type=TokenType.OAUTH_ACCOUNT,
account_id=account_id,
token_hash="testhash",
scopes=frozenset({Scope.FULL}),
)
def _account(account_id: str = "acct-1", email: str = "u@example.com") -> SimpleNamespace:
return SimpleNamespace(
id=account_id,
@ -268,7 +256,7 @@ def test_switch_returns_workspace_detail_with_current_true(app, bypass_pipeline,
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/switch", method="POST"):
_seed(_auth_ctx(account_id=acct_id))
body, status = api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
body, status = api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
assert status == 200
assert body["id"] == ws_id
@ -296,7 +284,7 @@ def test_switch_404s_when_service_raises_account_not_link_tenant(app, bypass_pip
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/switch", method="POST"):
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(NotFound):
api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
# ---------------------------------------------------------------------------
@ -330,7 +318,7 @@ def test_members_list_returns_normalized_rows(app, bypass_pipeline, monkeypatch)
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members"):
_seed(_auth_ctx(account_id=acct_id))
body, status = api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
body, status = api.get.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
assert status == 200
assert body["page"] == 1
@ -372,7 +360,7 @@ def test_members_list_paginates_with_query_params(app, bypass_pipeline, monkeypa
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members?page=2&limit=2"):
_seed(_auth_ctx(account_id=acct_id))
body, status = api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
body, status = api.get.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
assert status == 200
assert body["page"] == 2
@ -395,7 +383,7 @@ def test_members_list_rejects_unknown_query_param(app, bypass_pipeline, monkeypa
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members?pg=2"):
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(BadRequest):
api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
api.get.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
# ---------------------------------------------------------------------------
@ -433,7 +421,7 @@ def test_invite_happy_path_returns_invite_url_and_member_id(app, bypass_pipeline
content_type="application/json",
):
_seed(_auth_ctx(account_id=acct_id))
body, status = api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
body, status = api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
assert status == 201
assert body["result"] == "success"
@ -518,7 +506,7 @@ def test_invite_blocked_by_saas_members_cap(app, bypass_pipeline, monkeypatch):
with _invite_request(app, ws_id, acct_id):
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(Forbidden) as exc_info:
api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
body = exc_info.value.response.json
assert body["code"] == "members.limit_exceeded"
@ -564,7 +552,7 @@ def test_invite_blocked_by_ee_workspace_members_license(app, bypass_pipeline, mo
with _invite_request(app, ws_id, acct_id):
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(Forbidden) as exc_info:
api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
body = exc_info.value.response.json
assert body["code"] == "workspace_members.license_exceeded"
@ -603,7 +591,7 @@ def test_invite_ce_passes_when_both_caps_disabled(app, bypass_pipeline, monkeypa
with _invite_request(app, ws_id, acct_id):
_seed(_auth_ctx(account_id=acct_id))
body, status = api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
body, status = api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
assert status == 201
assert body["email"] == "new@example.com"
@ -632,7 +620,7 @@ def test_invite_400_when_already_in_tenant(app, bypass_pipeline, monkeypatch):
):
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(BadRequest):
api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
# ---------------------------------------------------------------------------
@ -665,8 +653,10 @@ def test_delete_member_happy_path(app, bypass_pipeline, monkeypatch):
method="DELETE",
):
_seed(_auth_ctx(account_id=acct_id))
body, status = api.delete.__wrapped__.__wrapped__(
api, workspace_id=ws_id, member_id=member_id, auth_data=_auth_data(acct_id)
body, status = api.delete.__wrapped__.__wrapped__.__wrapped__(
api,
workspace_id=ws_id,
member_id=member_id,
)
assert status == 200
@ -707,11 +697,10 @@ def test_delete_member_exception_mapping(app, bypass_pipeline, monkeypatch, exc,
):
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(expected):
api.delete.__wrapped__.__wrapped__(
api.delete.__wrapped__.__wrapped__.__wrapped__(
api,
workspace_id=ws_id,
member_id=member_id,
auth_data=_auth_data(acct_id),
)
@ -734,11 +723,10 @@ def test_delete_member_404_when_member_missing(app, bypass_pipeline, monkeypatch
):
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(NotFound):
api.delete.__wrapped__.__wrapped__(
api.delete.__wrapped__.__wrapped__.__wrapped__(
api,
workspace_id=ws_id,
member_id=member_id,
auth_data=_auth_data(acct_id),
)
@ -774,8 +762,10 @@ def test_update_role_happy_path(app, bypass_pipeline, monkeypatch):
content_type="application/json",
):
_seed(_auth_ctx(account_id=acct_id))
body, status = api.put.__wrapped__.__wrapped__(
api, workspace_id=ws_id, member_id=member_id, auth_data=_auth_data(acct_id)
body, status = api.put.__wrapped__.__wrapped__.__wrapped__(
api,
workspace_id=ws_id,
member_id=member_id,
)
assert status == 200
@ -820,11 +810,10 @@ def test_update_role_exception_mapping(app, bypass_pipeline, monkeypatch, exc, e
):
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(expected):
api.put.__wrapped__.__wrapped__(
api.put.__wrapped__.__wrapped__.__wrapped__(
api,
workspace_id=ws_id,
member_id=member_id,
auth_data=_auth_data(acct_id),
)
@ -858,8 +847,9 @@ def test_non_member_caller_gets_404_on_switch(app, bypass_pipeline, monkeypatch)
# Strip only the bearer + surface-gate wrappers; keep the role gate.
# Decorator stack (innermost → outermost):
# role_gate → accept_subjects → validate_bearer
# `post.__wrapped__` is now the role-gate wrapper directly (auth_router.guard is the only outer wrapper).
gated = api.post.__wrapped__
# So `post.__wrapped__` unwraps validate_bearer; we then unwrap
# accept_subjects to land on the role-gate wrapper.
gated = api.post.__wrapped__.__wrapped__
with pytest.raises(NotFound):
gated(api, workspace_id=ws_id)
@ -891,7 +881,7 @@ def test_load_tenant_rejects_archived_workspace(app, bypass_pipeline, monkeypatc
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members"):
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(NotFound):
api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
api.get.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)
# ---------------------------------------------------------------------------
@ -925,4 +915,4 @@ def test_invite_400_when_register_error(app, bypass_pipeline, monkeypatch):
):
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(BadRequest):
api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id)

View File

@ -1,12 +1,9 @@
from types import SimpleNamespace
from typing import cast
from agenton.compositor import CompositorSessionSnapshot
from clients.agent_backend import (
AgentBackendRunEventAdapter,
AgentBackendStreamInternalEvent,
CleanupLayerSpec,
FakeAgentBackendRunClient,
FakeAgentBackendScenario,
)
@ -16,10 +13,9 @@ from core.workflow.nodes.agent_v2.binding_resolver import WorkflowAgentBindingBu
from core.workflow.nodes.agent_v2.entities import DifyAgentNodeData
from core.workflow.nodes.agent_v2.output_adapter import WorkflowAgentOutputAdapter
from core.workflow.nodes.agent_v2.runtime_request_builder import WorkflowAgentRuntimeRequestBuilder
from core.workflow.nodes.agent_v2.session_store import WorkflowAgentRuntimeSessionStore, WorkflowAgentSessionScope
from graphon.entities import GraphInitParams
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from graphon.node_events import PauseRequestedEvent, StreamCompletedEvent
from graphon.node_events import StreamCompletedEvent
from graphon.runtime import GraphRuntimeState
from graphon.variables.segments import StringSegment
from models.agent import Agent, AgentConfigSnapshot, WorkflowAgentNodeBinding
@ -88,47 +84,7 @@ class FakeBindingResolver(WorkflowAgentBindingResolver):
return WorkflowAgentBindingBundle(binding=self.binding, agent=self.agent, snapshot=self.snapshot)
class FakeSessionStore:
def __init__(self, snapshot: CompositorSessionSnapshot | None = None) -> None:
self.loaded_snapshot = snapshot
self.saved: list[
tuple[
WorkflowAgentSessionScope,
str,
CompositorSessionSnapshot | None,
list[CleanupLayerSpec],
]
] = []
self.cleaned: list[tuple[WorkflowAgentSessionScope, str | None]] = []
def load_active_snapshot(self, scope: WorkflowAgentSessionScope) -> CompositorSessionSnapshot | None:
return self.loaded_snapshot
def save_active_snapshot(
self,
*,
scope: WorkflowAgentSessionScope,
backend_run_id: str,
snapshot: CompositorSessionSnapshot | None,
composition_layer_specs: list[CleanupLayerSpec],
) -> None:
self.saved.append((scope, backend_run_id, snapshot, list(composition_layer_specs)))
def mark_cleaned(
self,
*,
scope: WorkflowAgentSessionScope,
backend_run_id: str | None = None,
) -> None:
self.cleaned.append((scope, backend_run_id))
def _node(
*,
scenario: FakeAgentBackendScenario = FakeAgentBackendScenario.SUCCESS,
agent_backend_client: FakeAgentBackendRunClient | None = None,
session_store: FakeSessionStore | None = None,
) -> DifyAgentNode:
def _node(*, scenario: FakeAgentBackendScenario = FakeAgentBackendScenario.SUCCESS) -> DifyAgentNode:
graph_init_params = GraphInitParams(
workflow_id="workflow-1",
graph_config={"nodes": [], "edges": []},
@ -150,7 +106,6 @@ def _node(
def is_owned_by_tenant(self, *, file_id: str, tenant_id: str) -> bool:
return True
client = agent_backend_client or FakeAgentBackendRunClient(scenario=scenario)
return DifyAgentNode(
node_id="agent-node",
data=DifyAgentNodeData.model_validate({"type": BuiltinNodeTypes.AGENT, "version": "2"}),
@ -158,12 +113,11 @@ def _node(
graph_runtime_state=cast(GraphRuntimeState, SimpleNamespace(variable_pool=FakeVariablePool())),
binding_resolver=FakeBindingResolver(),
runtime_request_builder=WorkflowAgentRuntimeRequestBuilder(credentials_provider=FakeCredentialsProvider()),
agent_backend_client=client,
agent_backend_client=FakeAgentBackendRunClient(scenario=scenario),
event_adapter=AgentBackendRunEventAdapter(),
output_adapter=WorkflowAgentOutputAdapter(),
type_checker=PerOutputTypeChecker(file_validator=_AlwaysAllowFileValidator()),
failure_orchestrator=OutputFailureOrchestrator(),
session_store=cast(WorkflowAgentRuntimeSessionStore | None, session_store),
)
@ -178,7 +132,7 @@ def test_agent_node_run_maps_successful_agent_backend_run_to_node_result():
assert agent_log["agent_backend"]["run_id"] == "fake-run-1"
assert agent_log["agent_backend"]["status"] == "succeeded"
assert result.process_data["agent_id"] == "agent-1"
assert result.inputs["agent_backend_request"]["composition"]["layers"][5]["config"]["credentials"] == "[REDACTED]"
assert result.inputs["agent_backend_request"]["composition"]["layers"][4]["config"]["credentials"] == "[REDACTED]"
def test_agent_node_run_maps_failed_agent_backend_run_to_node_result():
@ -191,126 +145,6 @@ def test_agent_node_run_maps_failed_agent_backend_run_to_node_result():
assert result.error_type == "unit_test"
def test_agent_node_failed_run_marks_session_cleaned_to_prevent_stale_reuse():
"""A failed agent run must retire the local ACTIVE session row so a workflow
loop back into the same Agent node does not resume from a stale snapshot."""
existing_snapshot = CompositorSessionSnapshot(layers=[])
store = FakeSessionStore(snapshot=existing_snapshot)
events = list(_node(scenario=FakeAgentBackendScenario.FAILED, session_store=store)._run())
assert len(events) == 1
assert store.cleaned, "failed agent run should mark the session cleaned"
cleaned_scope, cleaned_backend_run_id = store.cleaned[0]
assert cleaned_scope.workflow_run_id == "workflow-run-1"
assert cleaned_backend_run_id == "fake-run-1"
# A failed run does not produce a fresh snapshot to persist.
assert store.saved == []
def test_agent_node_saves_success_snapshot_and_reuses_existing_snapshot():
existing_snapshot = CompositorSessionSnapshot(layers=[])
store = FakeSessionStore(snapshot=existing_snapshot)
client = FakeAgentBackendRunClient()
node = _node(agent_backend_client=client, session_store=store)
events = list(node._run())
assert len(events) == 1
assert store.saved
scope, backend_run_id, saved_snapshot, saved_specs = store.saved[0]
assert scope.workflow_run_id == "workflow-run-1"
assert backend_run_id == "fake-run-1"
assert saved_snapshot is not None
assert client.request is not None
assert client.request.session_snapshot is existing_snapshot
# Persist enough composition shape to replay a cleanup run; plugin layers
# (which would carry credentials) are intentionally absent.
saved_layer_names = [spec.name for spec in saved_specs]
assert saved_layer_names, "cleanup specs must persist at least the non-plugin layers"
plugin_types = {"dify.plugin.llm", "dify.plugin.tools"}
assert not {spec.type for spec in saved_specs} & plugin_types
def test_agent_node_run_when_session_store_save_raises_records_persist_error_in_metadata():
"""A DB-side write failure must not crash the node; it should set
``session_snapshot_persist_error`` in the agent_backend metadata so the
incident is observable from the workflow_node_executions record."""
class _ExplodingSessionStore(FakeSessionStore):
def save_active_snapshot(self, **kwargs): # type: ignore[override]
del kwargs
raise RuntimeError("simulated DB failure")
store = _ExplodingSessionStore()
events = list(_node(session_store=store)._run())
assert len(events) == 1
result = cast(StreamCompletedEvent, events[0]).node_run_result
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
agent_backend = result.metadata[WorkflowNodeExecutionMetadataKey.AGENT_LOG]["agent_backend"]
assert agent_backend["session_snapshot_persisted"] is False
assert agent_backend["session_snapshot_persist_error"] == "workflow_agent_runtime_session_store_error"
def test_agent_node_failed_run_when_mark_cleaned_raises_records_cleanup_error_in_metadata():
"""Same defensive pattern: a DB-side mark_cleaned failure must surface as
a ``session_snapshot_cleanup_error`` in metadata, not as a node crash."""
class _ExplodingMarkCleanedStore(FakeSessionStore):
def mark_cleaned(self, **kwargs): # type: ignore[override]
del kwargs
raise RuntimeError("simulated DB failure")
store = _ExplodingMarkCleanedStore()
events = list(_node(scenario=FakeAgentBackendScenario.FAILED, session_store=store)._run())
assert len(events) == 1
result = cast(StreamCompletedEvent, events[0]).node_run_result
assert result.status == WorkflowNodeExecutionStatus.FAILED
agent_backend = result.metadata[WorkflowNodeExecutionMetadataKey.AGENT_LOG]["agent_backend"]
assert agent_backend["session_snapshot_cleaned_on_failure"] is False
assert agent_backend["session_snapshot_cleanup_error"] == "workflow_agent_runtime_session_store_error"
def test_agent_node_success_run_without_session_store_skips_persistence():
"""When ``session_store`` is None the node still completes successfully —
the lifecycle branch is a no-op and the run result is unaffected."""
events = list(_node(session_store=None)._run())
assert len(events) == 1
result = cast(StreamCompletedEvent, events[0]).node_run_result
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
agent_backend = result.metadata[WorkflowNodeExecutionMetadataKey.AGENT_LOG]["agent_backend"]
# No persistence metadata is attached when the store is missing.
assert "session_snapshot_persisted" not in agent_backend
def test_agent_node_failed_run_without_session_store_skips_mark_cleaned():
"""``session_store=None`` + failed terminal must remain a no-op for
the cleanup branch — the node failure path still surfaces correctly."""
events = list(_node(scenario=FakeAgentBackendScenario.FAILED, session_store=None)._run())
assert len(events) == 1
result = cast(StreamCompletedEvent, events[0]).node_run_result
assert result.status == WorkflowNodeExecutionStatus.FAILED
agent_backend = result.metadata[WorkflowNodeExecutionMetadataKey.AGENT_LOG]["agent_backend"]
assert "session_snapshot_cleaned_on_failure" not in agent_backend
def test_agent_node_paused_run_requests_workflow_pause_and_persists_snapshot():
store = FakeSessionStore()
node = _node(scenario=FakeAgentBackendScenario.PAUSED, session_store=store)
events = list(node._run())
assert len(events) == 1
assert isinstance(events[0], PauseRequestedEvent)
assert store.saved
assert store.saved[0][1] == "fake-run-1"
assert store.saved[0][3], "paused agent run should still persist replayable layer specs"
def test_agent_node_records_stream_usage_metadata():
metadata = {"agent_backend": {"run_id": "run-1"}}

View File

@ -1,14 +1,10 @@
from dataclasses import replace
from typing import cast
import pytest
from agenton.compositor import CompositorSessionSnapshot
from dify_agent.layers.dify_plugin import DifyPluginToolConfig, DifyPluginToolsLayerConfig
from dify_agent.protocol import DIFY_AGENT_HISTORY_LAYER_ID, DIFY_AGENT_MODEL_LAYER_ID
from clients.agent_backend import DIFY_EXECUTION_CONTEXT_LAYER_ID, DIFY_PLUGIN_TOOLS_LAYER_ID
from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom
from core.workflow.nodes.agent_v2.plugin_tools_builder import WorkflowAgentPluginToolsBuilder
from core.workflow.nodes.agent_v2.runtime_request_builder import (
WorkflowAgentRuntimeBuildContext,
WorkflowAgentRuntimeRequestBuilder,
@ -31,17 +27,6 @@ class FakeCredentialsProvider:
return {"api_key": "secret-key"}
class CapturingCredentialsProvider:
def __init__(self) -> None:
self.provider_name: str | None = None
self.model_name: str | None = None
def fetch(self, provider_name: str, model_name: str) -> dict[str, object]:
self.provider_name = provider_name
self.model_name = model_name
return {"api_key": "secret-key"}
class FakePluginToolsBuilder:
def __init__(self) -> None:
# Capture the runtime invocation source so tests can assert it was
@ -151,31 +136,7 @@ def test_builds_create_run_request_from_agent_soul_and_node_job():
assert dumped["composition"]["layers"][1]["config"]["prefix"] == "Use the previous output."
assert "Previous result" in dumped["composition"]["layers"][2]["config"]["user"]
assert dumped["composition"]["layers"][-1]["config"]["json_schema"]["properties"]["summary"]["type"] == "string"
assert DIFY_AGENT_HISTORY_LAYER_ID in layers
assert result.redacted_request["composition"]["layers"][5]["config"]["credentials"] == "[REDACTED]"
def test_normalizes_langgenius_model_provider_for_agent_backend_transport():
context = _context()
context.snapshot.config_snapshot = AgentSoulConfig(
prompt={"system_prompt": "You are careful."},
model=AgentSoulModelConfig(
plugin_id="langgenius/openai/openai",
model_provider="langgenius/openai/openai",
model="gpt-test",
),
)
credentials_provider = CapturingCredentialsProvider()
result = WorkflowAgentRuntimeRequestBuilder(credentials_provider=credentials_provider).build(context)
dumped = result.request.model_dump(mode="json")
layers = {layer["name"]: layer for layer in dumped["composition"]["layers"]}
model_config = layers[DIFY_AGENT_MODEL_LAYER_ID]["config"]
assert credentials_provider.provider_name == "langgenius/openai/openai"
assert credentials_provider.model_name == "gpt-test"
assert model_config["plugin_id"] == "langgenius/openai"
assert model_config["model_provider"] == "openai"
assert result.redacted_request["composition"]["layers"][4]["config"]["credentials"] == "[REDACTED]"
def test_builds_workflow_run_request_with_file_output_schema_and_reserved_metadata():
@ -226,7 +187,7 @@ def test_builds_workflow_run_request_with_file_output_schema_and_reserved_metada
assert output_schema["properties"]["report"]["properties"]["file_id"]["type"] == "string"
assert output_schema["properties"]["confidence"]["type"] == "number"
assert output_schema["required"] == ["report"]
assert dumped["composition"]["layers"][5]["config"]["model_settings"] == {"temperature": 0.2}
assert dumped["composition"]["layers"][4]["config"]["model_settings"] == {"temperature": 0.2}
assert result.metadata["runtime_support"]["reserved_status"]["tools.dify_tools"] == "supported_when_config_valid"
assert result.metadata["runtime_support"]["reserved_status"]["tools.cli_tools"] == "reserved_not_executed"
warnings = result.metadata["runtime_support"]["unsupported_runtime_warnings"]
@ -263,7 +224,7 @@ def test_builds_workflow_run_request_with_dify_plugin_tools_layer():
plugin_tools_builder = FakePluginToolsBuilder()
result = WorkflowAgentRuntimeRequestBuilder(
credentials_provider=FakeCredentialsProvider(),
plugin_tools_builder=cast(WorkflowAgentPluginToolsBuilder, plugin_tools_builder),
plugin_tools_builder=plugin_tools_builder,
).build(context)
dumped = result.request.model_dump(mode="json")
@ -283,15 +244,6 @@ def test_builds_workflow_run_request_with_dify_plugin_tools_layer():
assert plugin_tools_builder.last_invoke_from == context.dify_context.invoke_from
def test_build_passes_saved_session_snapshot_to_agent_backend_request():
session_snapshot = CompositorSessionSnapshot(layers=[])
context = replace(_context(), session_snapshot=session_snapshot)
result = WorkflowAgentRuntimeRequestBuilder(credentials_provider=FakeCredentialsProvider()).build(context)
assert result.request.session_snapshot is session_snapshot
def test_requires_agent_soul_model_config():
context = _context()
snapshot = AgentConfigSnapshot(

View File

@ -1,412 +0,0 @@
from datetime import UTC
from typing import cast
import pytest
from agenton.compositor import CompositorSessionSnapshot
from agenton.compositor.schemas import LayerSessionSnapshot
from agenton.layers.base import LifecycleState
from dify_agent.protocol import CancelRunRequest, RunEvent, RunStatusResponse
from clients.agent_backend import AgentBackendRunRequestBuilder, CleanupLayerSpec, FakeAgentBackendRunClient
from clients.agent_backend.errors import AgentBackendHTTPError
from core.workflow.nodes.agent_v2.session_cleanup_layer import WorkflowAgentSessionCleanupLayer
from core.workflow.nodes.agent_v2.session_store import (
StoredWorkflowAgentSession,
WorkflowAgentRuntimeSessionStore,
WorkflowAgentSessionScope,
)
from core.workflow.system_variables import build_system_variables
from graphon.entities.pause_reason import SchedulingPause
from graphon.graph_engine.command_channels import CommandChannel
from graphon.graph_events import (
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
def _layer_snapshot(name: str) -> LayerSessionSnapshot:
return LayerSessionSnapshot(
name=name,
lifecycle_state=LifecycleState.SUSPENDED,
runtime_state={},
)
def _stored_session(scope: WorkflowAgentSessionScope, *, index: int = 1) -> StoredWorkflowAgentSession:
"""A typical stored session with prompt + execution_context + history + llm specs.
The LLM layer is *not* in ``composition_layer_specs`` because the cleanup
contract excludes credential-bearing plugin layers, but it *is* present in
the saved snapshot so the layer's filter logic gets exercised.
"""
return StoredWorkflowAgentSession(
scope=scope,
session_snapshot=CompositorSessionSnapshot(
layers=[
_layer_snapshot("workflow_node_job_prompt"),
_layer_snapshot("execution_context"),
_layer_snapshot("history"),
_layer_snapshot("llm"),
]
),
backend_run_id=f"agent-run-{index}",
composition_layer_specs=[
CleanupLayerSpec(name="workflow_node_job_prompt", type="plain.prompt", config={"prefix": "ok"}),
CleanupLayerSpec(name="execution_context", type="dify.execution_context", config={"tenant_id": "t"}),
CleanupLayerSpec(name="history", type="pydantic_ai.history"),
],
)
class FakeSessionStore:
"""In-memory stand-in for ``WorkflowAgentRuntimeSessionStore``."""
def __init__(self, *, stored: list[StoredWorkflowAgentSession] | None = None) -> None:
self._stored = stored if stored is not None else [_stored_session(_default_scope())]
self.list_calls: list[str] = []
self.cleaned: list[tuple[WorkflowAgentSessionScope, str | None]] = []
def list_active_sessions(self, *, workflow_run_id: str) -> list[StoredWorkflowAgentSession]:
self.list_calls.append(workflow_run_id)
return list(self._stored)
def mark_cleaned(self, *, scope: WorkflowAgentSessionScope, backend_run_id: str | None = None) -> None:
self.cleaned.append((scope, backend_run_id))
def _default_scope() -> WorkflowAgentSessionScope:
return WorkflowAgentSessionScope(
tenant_id="tenant-1",
app_id="app-1",
workflow_id="workflow-1",
workflow_run_id="workflow-run-1",
node_id="agent-node",
node_execution_id="node-exec-1",
binding_id="binding-1",
agent_id="agent-1",
agent_config_snapshot_id="snapshot-1",
)
class _WaitableFakeAgentBackendRunClient(FakeAgentBackendRunClient):
"""``FakeAgentBackendRunClient`` plus the ``wait_run`` hook the layer needs."""
def __init__(
self,
*,
run_id: str = "cleanup-run-1",
wait_status: str = "succeeded",
wait_error: str | None = None,
wait_raises: Exception | None = None,
) -> None:
super().__init__(run_id=run_id)
self._wait_status = wait_status
self._wait_error = wait_error
self._wait_raises = wait_raises
self.wait_calls: list[tuple[str, float | None]] = []
def wait_run(self, run_id: str, *, timeout_seconds: float | None = None) -> RunStatusResponse:
self.wait_calls.append((run_id, timeout_seconds))
if self._wait_raises is not None:
raise self._wait_raises
from datetime import datetime
return RunStatusResponse(
run_id=run_id,
status=cast(object, self._wait_status), # protocol Literal; cast keeps tests flexible
created_at=datetime(2026, 1, 1, tzinfo=UTC),
updated_at=datetime(2026, 1, 1, tzinfo=UTC),
error=self._wait_error,
)
# Inherit ``create_run`` from FakeAgentBackendRunClient; the missing protocol
# methods below are stub-only because the cleanup layer never calls them.
def cancel_run(self, run_id: str, request: CancelRunRequest | None = None): # pragma: no cover
del run_id, request
raise NotImplementedError
def stream_events(self, run_id: str, *, after: str | None = None): # pragma: no cover
del run_id, after
if False:
yield cast(RunEvent, None)
def _build_layer(
*,
session_store: FakeSessionStore,
agent_backend_client: _WaitableFakeAgentBackendRunClient,
http_cleanup_supported: bool = True,
) -> WorkflowAgentSessionCleanupLayer:
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="workflow-run-1"),
user_inputs={},
conversation_variables=[],
)
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
layer = WorkflowAgentSessionCleanupLayer(
session_store=cast(WorkflowAgentRuntimeSessionStore, session_store),
request_builder=AgentBackendRunRequestBuilder(),
agent_backend_client=agent_backend_client,
)
# Tests opt in to the future HTTP-cleanup branch; the production default
# (False) is exercised by the dedicated tests below.
layer._HTTP_CLEANUP_SUPPORTED = http_cleanup_supported # type: ignore[reportPrivateUsage]
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(runtime_state), cast(CommandChannel, object()))
return layer
@pytest.mark.parametrize(
"terminal_event",
[
GraphRunSucceededEvent(outputs={}),
GraphRunPartialSucceededEvent(exceptions_count=1, outputs={}),
GraphRunFailedEvent(error="boom"),
GraphRunAbortedEvent(reason="user cancelled", outputs={}),
],
ids=["succeeded", "partial_succeeded", "failed", "aborted"],
)
def test_cleanup_layer_triggers_cleanup_only_run_on_each_terminal_event(terminal_event):
session_store = FakeSessionStore()
agent_backend_client = _WaitableFakeAgentBackendRunClient()
layer = _build_layer(session_store=session_store, agent_backend_client=agent_backend_client)
layer.on_event(terminal_event)
assert session_store.list_calls == ["workflow-run-1"]
assert agent_backend_client.request is not None
# Cleanup composition replays the persisted (non-plugin) layer specs so the
# agent backend's snapshot-vs-composition name match succeeds.
layer_names = [layer.name for layer in agent_backend_client.request.composition.layers]
assert layer_names == ["workflow_node_job_prompt", "execution_context", "history"]
assert agent_backend_client.request.on_exit.default.value == "delete"
assert agent_backend_client.request.metadata["agent_backend_lifecycle"] == "session_cleanup"
# Snapshot is filtered to drop the plugin layer entry so names match the
# cleanup composition.
assert agent_backend_client.request.session_snapshot is not None
snapshot_names = [layer.name for layer in agent_backend_client.request.session_snapshot.layers]
assert snapshot_names == ["workflow_node_job_prompt", "execution_context", "history"]
# The layer waited for terminal status and the run succeeded, so the row
# is marked CLEANED with the cleanup run id.
assert agent_backend_client.wait_calls
assert session_store.cleaned == [(_default_scope(), "cleanup-run-1")]
@pytest.mark.parametrize(
"non_terminal_event",
[
GraphRunStartedEvent(),
GraphRunPausedEvent(reasons=[SchedulingPause(message="awaiting human input")], outputs={}),
],
ids=["started", "paused"],
)
def test_cleanup_layer_ignores_non_terminal_events(non_terminal_event):
session_store = FakeSessionStore()
agent_backend_client = _WaitableFakeAgentBackendRunClient()
layer = _build_layer(session_store=session_store, agent_backend_client=agent_backend_client)
layer.on_event(non_terminal_event)
assert session_store.list_calls == []
assert agent_backend_client.request is None
assert session_store.cleaned == []
def test_cleanup_layer_does_not_mark_cleaned_when_cleanup_run_fails():
"""Trap D: cleanup-only run goes ``run_failed`` (e.g. snapshot validation
error) — the layer must leave the row ACTIVE so it can be retried instead
of silently leaking suspended agent-backend layers."""
session_store = FakeSessionStore()
agent_backend_client = _WaitableFakeAgentBackendRunClient(
wait_status="failed",
wait_error="snapshot mismatch",
)
layer = _build_layer(session_store=session_store, agent_backend_client=agent_backend_client)
layer.on_event(GraphRunSucceededEvent(outputs={}))
assert agent_backend_client.wait_calls
assert session_store.cleaned == []
def test_cleanup_layer_does_not_mark_cleaned_when_wait_raises():
session_store = FakeSessionStore()
agent_backend_client = _WaitableFakeAgentBackendRunClient(
wait_raises=AgentBackendHTTPError("boom", status_code=500, detail=None),
)
layer = _build_layer(session_store=session_store, agent_backend_client=agent_backend_client)
layer.on_event(GraphRunSucceededEvent(outputs={}))
assert session_store.cleaned == []
def test_cleanup_layer_marks_cleaned_locally_when_http_cleanup_disabled():
"""Production default: dify-agent has no cleanup-only run mode yet, so the
layer must retire the local row without issuing a doomed HTTP request that
would crash inside the agent backend's runner on the missing LLM layer."""
session_store = FakeSessionStore()
agent_backend_client = _WaitableFakeAgentBackendRunClient()
layer = _build_layer(
session_store=session_store,
agent_backend_client=agent_backend_client,
http_cleanup_supported=False,
)
layer.on_event(GraphRunSucceededEvent(outputs={}))
# No HTTP call goes out — the trap is avoided entirely.
assert agent_backend_client.request is None
assert agent_backend_client.wait_calls == []
# Local row is still retired so a workflow loop cannot resume from stale state.
assert session_store.cleaned == [(_default_scope(), "agent-run-1")]
def test_cleanup_layer_skips_sessions_without_persisted_specs():
"""Backwards-compatible safety net: a row written before A.1 landed has
no composition_layer_specs, so cleanup would unavoidably hit the snapshot-
validation trap. The layer must skip such rows instead of issuing a
doomed request."""
scope = _default_scope()
legacy_session = StoredWorkflowAgentSession(
scope=scope,
session_snapshot=CompositorSessionSnapshot(layers=[_layer_snapshot("history")]),
backend_run_id="legacy-run",
composition_layer_specs=[],
)
session_store = FakeSessionStore(stored=[legacy_session])
agent_backend_client = _WaitableFakeAgentBackendRunClient()
layer = _build_layer(session_store=session_store, agent_backend_client=agent_backend_client)
layer.on_event(GraphRunSucceededEvent(outputs={}))
assert agent_backend_client.request is None
assert session_store.cleaned == []
def test_cleanup_layer_fans_out_to_every_active_session():
scopes = [
WorkflowAgentSessionScope(
tenant_id="tenant-1",
app_id="app-1",
workflow_id="workflow-1",
workflow_run_id="workflow-run-1",
node_id=f"agent-node-{i}",
node_execution_id=f"node-exec-{i}",
binding_id=f"binding-{i}",
agent_id=f"agent-{i}",
agent_config_snapshot_id=f"snapshot-{i}",
)
for i in range(3)
]
session_store = FakeSessionStore(stored=[_stored_session(scope, index=i) for i, scope in enumerate(scopes, 1)])
agent_backend_client = _WaitableFakeAgentBackendRunClient(run_id="cleanup-run-many")
layer = _build_layer(session_store=session_store, agent_backend_client=agent_backend_client)
layer.on_event(GraphRunSucceededEvent(outputs={}))
# One cleanup row per stored ACTIVE session, all marked cleaned with the
# backend run id returned by the agent backend client.
assert [entry[0] for entry in session_store.cleaned] == scopes
assert {entry[1] for entry in session_store.cleaned} == {"cleanup-run-many"}
def test_cleanup_layer_warns_when_http_enabled_but_client_missing(caplog):
"""The HTTP cleanup branch must defensively skip when no client was wired.
This is the deployment-misconfig path: ``_HTTP_CLEANUP_SUPPORTED`` was
flipped to ``True`` but ``AGENT_BACKEND_BASE_URL`` is unset, so the
factory returned ``None``. The layer must not crash and must not silently
retire the row — the warning surfaces the misconfig.
"""
import logging
session_store = FakeSessionStore()
layer = WorkflowAgentSessionCleanupLayer(
session_store=cast(WorkflowAgentRuntimeSessionStore, session_store),
request_builder=AgentBackendRunRequestBuilder(),
agent_backend_client=None,
)
layer._HTTP_CLEANUP_SUPPORTED = True # type: ignore[reportPrivateUsage]
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="workflow-run-1"),
user_inputs={},
conversation_variables=[],
)
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(runtime_state), cast(CommandChannel, object()))
with caplog.at_level(logging.WARNING):
layer.on_event(GraphRunSucceededEvent(outputs={}))
assert session_store.cleaned == []
assert any("no agent backend client is wired in" in record.message for record in caplog.records)
def test_cleanup_layer_skips_workflow_terminal_when_workflow_run_id_missing(caplog):
"""``workflow_run_id`` is the keying field; without it the fanout cannot
target a row, so the layer logs a warning and bails."""
import logging
session_store = FakeSessionStore()
agent_backend_client = _WaitableFakeAgentBackendRunClient()
layer = WorkflowAgentSessionCleanupLayer(
session_store=cast(WorkflowAgentRuntimeSessionStore, session_store),
request_builder=AgentBackendRunRequestBuilder(),
agent_backend_client=agent_backend_client,
)
# Bootstrap *without* a workflow_execution_id system variable.
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id=""),
user_inputs={},
conversation_variables=[],
)
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(runtime_state), cast(CommandChannel, object()))
with caplog.at_level(logging.WARNING):
layer.on_event(GraphRunSucceededEvent(outputs={}))
assert session_store.list_calls == []
assert session_store.cleaned == []
assert any("workflow_run_id is missing" in record.message for record in caplog.records)
def test_build_workflow_agent_session_cleanup_layer_returns_layer_without_client_when_unconfigured(
monkeypatch,
):
"""The production builder must pass ``None`` for the agent backend client
when neither AGENT_BACKEND_BASE_URL nor AGENT_BACKEND_USE_FAKE is set, so
that unit-test environments without backend config don't crash at runner
construction."""
from configs import dify_config
from core.workflow.nodes.agent_v2.session_cleanup_layer import (
build_workflow_agent_session_cleanup_layer,
)
monkeypatch.setattr(dify_config, "AGENT_BACKEND_BASE_URL", None, raising=False)
monkeypatch.setattr(dify_config, "AGENT_BACKEND_USE_FAKE", False, raising=False)
layer = build_workflow_agent_session_cleanup_layer()
assert layer._agent_backend_client is None # type: ignore[reportPrivateUsage]
def test_build_workflow_agent_session_cleanup_layer_returns_layer_with_fake_client(monkeypatch):
"""With ``AGENT_BACKEND_USE_FAKE`` enabled the helper wires in the
deterministic fake client without needing a base_url."""
from clients.agent_backend.fake_client import FakeAgentBackendRunClient
from configs import dify_config
from core.workflow.nodes.agent_v2.session_cleanup_layer import (
build_workflow_agent_session_cleanup_layer,
)
monkeypatch.setattr(dify_config, "AGENT_BACKEND_BASE_URL", None, raising=False)
monkeypatch.setattr(dify_config, "AGENT_BACKEND_USE_FAKE", True, raising=False)
monkeypatch.setattr(dify_config, "AGENT_BACKEND_FAKE_SCENARIO", "success", raising=False)
layer = build_workflow_agent_session_cleanup_layer()
assert isinstance(layer._agent_backend_client, FakeAgentBackendRunClient) # type: ignore[reportPrivateUsage]

View File

@ -1,286 +0,0 @@
"""Unit tests for :mod:`core.workflow.nodes.agent_v2.session_store`.
Uses the in-memory SQLite engine configured by the project conftest plus a
per-test ``CREATE TABLE`` so the real ORM round-trip exercises every store
method. Keeps the suite self-contained — no Postgres / Docker required — while
still hitting the actual ``session_factory`` code path that production uses.
"""
from __future__ import annotations
from collections.abc import Generator
import pytest
from agenton.compositor import CompositorSessionSnapshot
from agenton.compositor.schemas import LayerSessionSnapshot
from agenton.layers.base import LifecycleState
from sqlalchemy import delete
from clients.agent_backend.request_builder import CleanupLayerSpec
from core.db.session_factory import session_factory
from core.workflow.nodes.agent_v2.session_store import (
StoredWorkflowAgentSession,
WorkflowAgentRuntimeSessionStore,
WorkflowAgentSessionScope,
)
from models.agent import WorkflowAgentRuntimeSession, WorkflowAgentRuntimeSessionStatus
def _scope(workflow_run_id: str | None = "wfr-1", binding_id: str = "binding-1") -> WorkflowAgentSessionScope:
return WorkflowAgentSessionScope(
tenant_id="tenant-1",
app_id="app-1",
workflow_id="workflow-1",
workflow_run_id=workflow_run_id,
node_id="agent-node",
node_execution_id="node-exec-1",
binding_id=binding_id,
agent_id="agent-1",
agent_config_snapshot_id="snapshot-1",
)
def _snapshot(messages: int = 1) -> CompositorSessionSnapshot:
return CompositorSessionSnapshot(
layers=[
LayerSessionSnapshot(
name="history",
lifecycle_state=LifecycleState.SUSPENDED,
runtime_state={"messages": [{"role": "user", "content": f"m{i}"} for i in range(messages)]},
)
]
)
def _specs() -> list[CleanupLayerSpec]:
return [
CleanupLayerSpec(name="workflow_node_job_prompt", type="plain.prompt", config={"prefix": "ok"}),
CleanupLayerSpec(name="history", type="pydantic_ai.history"),
]
@pytest.fixture(autouse=True)
def _create_table() -> Generator[None, None, None]:
"""Create the lifecycle table on the in-memory SQLite engine, drop after."""
engine = session_factory.get_session_maker().kw["bind"]
WorkflowAgentRuntimeSession.__table__.create(bind=engine, checkfirst=True)
yield
with session_factory.create_session() as session:
session.execute(delete(WorkflowAgentRuntimeSession))
session.commit()
WorkflowAgentRuntimeSession.__table__.drop(bind=engine, checkfirst=True)
def test_load_active_snapshot_returns_none_when_scope_has_no_workflow_run_id():
"""``workflow_run_id`` is the keying column; no row can match without it."""
store = WorkflowAgentRuntimeSessionStore()
assert store.load_active_snapshot(_scope(workflow_run_id=None)) is None
def test_load_active_snapshot_returns_none_when_no_row_matches():
store = WorkflowAgentRuntimeSessionStore()
assert store.load_active_snapshot(_scope()) is None
def test_save_active_snapshot_creates_row_and_load_round_trips():
store = WorkflowAgentRuntimeSessionStore()
snapshot = _snapshot(messages=2)
store.save_active_snapshot(
scope=_scope(), backend_run_id="run-1", snapshot=snapshot, composition_layer_specs=_specs()
)
loaded = store.load_active_snapshot(_scope())
assert loaded is not None
assert len(loaded.layers) == 1
assert loaded.layers[0].name == "history"
assert loaded.layers[0].runtime_state["messages"] == snapshot.layers[0].runtime_state["messages"]
def test_save_active_snapshot_skips_when_workflow_run_id_missing():
"""Without a workflow_run_id the row cannot be keyed; save is a no-op."""
store = WorkflowAgentRuntimeSessionStore()
store.save_active_snapshot(
scope=_scope(workflow_run_id=None),
backend_run_id="run-skipped",
snapshot=_snapshot(),
composition_layer_specs=_specs(),
)
with session_factory.create_session() as session:
assert session.query(WorkflowAgentRuntimeSession).count() == 0
def test_save_active_snapshot_skips_when_snapshot_missing():
"""A run that produced no snapshot (e.g. failed agent run) does not write."""
store = WorkflowAgentRuntimeSessionStore()
store.save_active_snapshot(
scope=_scope(),
backend_run_id="run-empty",
snapshot=None,
composition_layer_specs=_specs(),
)
with session_factory.create_session() as session:
assert session.query(WorkflowAgentRuntimeSession).count() == 0
def test_save_active_snapshot_updates_existing_row_on_re_entry():
"""A second save under the same scope must update in place, not insert."""
store = WorkflowAgentRuntimeSessionStore()
store.save_active_snapshot(
scope=_scope(),
backend_run_id="run-1",
snapshot=_snapshot(messages=1),
composition_layer_specs=_specs(),
)
# Second call with new snapshot + backend_run_id.
store.save_active_snapshot(
scope=_scope(),
backend_run_id="run-2",
snapshot=_snapshot(messages=2),
composition_layer_specs=_specs(),
)
with session_factory.create_session() as session:
rows = session.query(WorkflowAgentRuntimeSession).all()
assert len(rows) == 1
assert rows[0].backend_run_id == "run-2"
assert rows[0].status == WorkflowAgentRuntimeSessionStatus.ACTIVE
assert rows[0].cleaned_at is None
def test_save_active_snapshot_resurrects_cleaned_row():
"""If a prior cleanup retired the row, a re-entry flips it back to ACTIVE."""
store = WorkflowAgentRuntimeSessionStore()
store.save_active_snapshot(
scope=_scope(),
backend_run_id="run-1",
snapshot=_snapshot(),
composition_layer_specs=_specs(),
)
store.mark_cleaned(scope=_scope(), backend_run_id="cleanup-1")
# Save again — the existing row was CLEANED; should be revived.
store.save_active_snapshot(
scope=_scope(),
backend_run_id="run-2",
snapshot=_snapshot(messages=3),
composition_layer_specs=_specs(),
)
with session_factory.create_session() as session:
rows = session.query(WorkflowAgentRuntimeSession).all()
assert len(rows) == 1
assert rows[0].status == WorkflowAgentRuntimeSessionStatus.ACTIVE
assert rows[0].cleaned_at is None
assert rows[0].backend_run_id == "run-2"
def test_list_active_sessions_returns_specs_and_snapshot():
store = WorkflowAgentRuntimeSessionStore()
store.save_active_snapshot(
scope=_scope(binding_id="binding-A"),
backend_run_id="run-A",
snapshot=_snapshot(),
composition_layer_specs=_specs(),
)
store.save_active_snapshot(
scope=_scope(binding_id="binding-B"),
backend_run_id="run-B",
snapshot=_snapshot(messages=2),
composition_layer_specs=_specs(),
)
listed = store.list_active_sessions(workflow_run_id="wfr-1")
assert {s.backend_run_id for s in listed} == {"run-A", "run-B"}
by_run = {s.backend_run_id: s for s in listed}
assert isinstance(by_run["run-A"], StoredWorkflowAgentSession)
# Specs round-trip through pydantic TypeAdapter — ensure deserialize works.
assert by_run["run-A"].composition_layer_specs[0].name == "workflow_node_job_prompt"
assert by_run["run-A"].composition_layer_specs[1].type == "pydantic_ai.history"
# node_execution_id default-replaces NULL with "" when the DB column is None.
assert by_run["run-A"].scope.node_execution_id == "node-exec-1"
def test_list_active_sessions_skips_cleaned_rows():
store = WorkflowAgentRuntimeSessionStore()
store.save_active_snapshot(
scope=_scope(binding_id="binding-A"),
backend_run_id="run-A",
snapshot=_snapshot(),
composition_layer_specs=_specs(),
)
store.save_active_snapshot(
scope=_scope(binding_id="binding-B"),
backend_run_id="run-B",
snapshot=_snapshot(),
composition_layer_specs=_specs(),
)
store.mark_cleaned(scope=_scope(binding_id="binding-A"), backend_run_id="cleanup-A")
listed = store.list_active_sessions(workflow_run_id="wfr-1")
assert {s.backend_run_id for s in listed} == {"run-B"}
def test_list_active_sessions_handles_legacy_rows_without_specs():
"""Rows persisted before composition_layer_specs landed have an empty string."""
# Insert a legacy-shape row directly: empty specs payload simulates a row
# written before the spec persistence feature landed in A.1.
store = WorkflowAgentRuntimeSessionStore()
store.save_active_snapshot(
scope=_scope(),
backend_run_id="run-legacy",
snapshot=_snapshot(),
composition_layer_specs=[],
)
listed = store.list_active_sessions(workflow_run_id="wfr-1")
assert len(listed) == 1
assert listed[0].composition_layer_specs == []
def test_mark_cleaned_sets_status_and_cleaned_at_with_backend_run_id():
store = WorkflowAgentRuntimeSessionStore()
store.save_active_snapshot(
scope=_scope(),
backend_run_id="run-1",
snapshot=_snapshot(),
composition_layer_specs=_specs(),
)
store.mark_cleaned(scope=_scope(), backend_run_id="cleanup-1")
with session_factory.create_session() as session:
row = session.query(WorkflowAgentRuntimeSession).one()
assert row.status == WorkflowAgentRuntimeSessionStatus.CLEANED
assert row.cleaned_at is not None
assert row.backend_run_id == "cleanup-1"
def test_mark_cleaned_preserves_existing_backend_run_id_when_none_given():
"""``backend_run_id=None`` means "leave the previous one in place"."""
store = WorkflowAgentRuntimeSessionStore()
store.save_active_snapshot(
scope=_scope(),
backend_run_id="run-1",
snapshot=_snapshot(),
composition_layer_specs=_specs(),
)
store.mark_cleaned(scope=_scope(), backend_run_id=None)
with session_factory.create_session() as session:
row = session.query(WorkflowAgentRuntimeSession).one()
assert row.status == WorkflowAgentRuntimeSessionStatus.CLEANED
assert row.backend_run_id == "run-1"
def test_mark_cleaned_is_a_noop_when_no_active_row():
"""No matching ACTIVE row → no-op (already-cleaned rows are not re-touched)."""
store = WorkflowAgentRuntimeSessionStore()
store.mark_cleaned(scope=_scope(), backend_run_id="cleanup-1")
with session_factory.create_session() as session:
assert session.query(WorkflowAgentRuntimeSession).count() == 0
def test_mark_cleaned_is_a_noop_when_workflow_run_id_missing():
"""Without a workflow_run_id we cannot key the row; ignore the call."""
store = WorkflowAgentRuntimeSessionStore()
store.mark_cleaned(scope=_scope(workflow_run_id=None), backend_run_id="cleanup-1")
# Sanity — no rows created or touched.
with session_factory.create_session() as session:
assert session.query(WorkflowAgentRuntimeSession).count() == 0

View File

@ -11,7 +11,6 @@ from libs.oauth_bearer import (
SubjectType,
TokenKind,
TokenKindRegistry,
TokenType,
)
@ -22,7 +21,7 @@ def _registry_with_resolver(resolver) -> TokenKindRegistry:
prefix="dfoa_",
subject_type=SubjectType.ACCOUNT,
scopes=frozenset({Scope.FULL}),
token_type=TokenType.OAUTH_ACCOUNT,
source="oauth_account",
resolver=resolver,
)
]
@ -64,7 +63,7 @@ def test_unknown_prefix_raises_generic_invalid_bearer():
prefix="dfoa_",
subject_type=SubjectType.ACCOUNT,
scopes=frozenset({Scope.FULL}),
token_type=TokenType.OAUTH_ACCOUNT,
source="oauth_account",
resolver=MagicMock(),
)
]

View File

@ -19,7 +19,6 @@ from libs.oauth_bearer import (
AuthContext,
Scope,
SubjectType,
TokenType,
require_scope,
reset_auth_ctx,
set_auth_ctx,
@ -51,7 +50,7 @@ def _ctx(scopes) -> AuthContext:
client_id="difyctl",
scopes=scopes,
token_id=uuid.uuid4(),
token_type=TokenType.OAUTH_ACCOUNT,
source="oauth_account",
expires_at=None,
token_hash="h1",
verified_tenants={},

View File

@ -8,7 +8,7 @@ from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden
from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, require_workspace_member
from libs.oauth_bearer import AuthContext, Scope, SubjectType, require_workspace_member
def _ctx(verified: dict[str, bool] | None = None, *, account: bool = True) -> AuthContext:
@ -20,7 +20,7 @@ def _ctx(verified: dict[str, bool] | None = None, *, account: bool = True) -> Au
client_id="difyctl",
scopes=frozenset({Scope.FULL}),
token_id=uuid.uuid4(),
token_type=TokenType.OAUTH_ACCOUNT if account else TokenType.OAUTH_EXTERNAL_SSO,
source="oauth_account",
expires_at=None,
token_hash="h1",
verified_tenants=dict(verified or {}),

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import uuid
from unittest.mock import MagicMock
from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT, AuthContext, SubjectType, TokenType
from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT, AuthContext, SubjectType
from services.oauth_device_flow import (
list_active_sessions,
revoke_oauth_token,
@ -21,7 +21,7 @@ def _account_ctx() -> AuthContext:
client_id="difyctl",
scopes=frozenset({"full"}),
token_id=uuid.uuid4(),
token_type=TokenType.OAUTH_ACCOUNT,
source="oauth_account",
expires_at=None,
token_hash="h1",
verified_tenants={},
@ -37,7 +37,7 @@ def _sso_ctx() -> AuthContext:
client_id="difyctl",
scopes=frozenset({"apps:run"}),
token_id=uuid.uuid4(),
token_type=TokenType.OAUTH_EXTERNAL_SSO,
source="oauth_external_sso",
expires_at=None,
token_hash="h1",
verified_tenants={},

View File

@ -0,0 +1,115 @@
"""
Unit tests for delete_account_task.
Covers:
- Billing enabled with existing account: calls billing and sends success email
- Billing disabled with existing account: skips billing, sends success email
- Account not found: still calls billing when enabled, does not send email
- Billing deletion raises: logs and re-raises, no email
"""
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from tasks.delete_account_task import delete_account_task
@pytest.fixture
def mock_db_session():
"""Mock session via session_factory.create_session()."""
with patch("tasks.delete_account_task.session_factory") as mock_sf:
session = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
cm.__exit__.return_value = None
mock_sf.create_session.return_value = cm
yield session
@pytest.fixture
def mock_deps():
"""Patch external dependencies: BillingService and send_deletion_success_task."""
with (
patch("tasks.delete_account_task.BillingService") as mock_billing,
patch("tasks.delete_account_task.send_deletion_success_task") as mock_mail_task,
):
# ensure .delay exists on the mail task
mock_mail_task.delay = MagicMock()
yield {
"billing": mock_billing,
"mail_task": mock_mail_task,
}
def _set_account_found(mock_db_session, email: str = "user@example.com"):
account = SimpleNamespace(email=email)
mock_db_session.scalar.return_value = account
return account
def _set_account_missing(mock_db_session):
mock_db_session.scalar.return_value = None
class TestDeleteAccountTask:
def test_billing_enabled_account_exists_calls_billing_and_sends_email(self, mock_db_session, mock_deps):
# Arrange
account_id = "acc-123"
account = _set_account_found(mock_db_session, email="a@b.com")
# Enable billing
with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True):
# Act
delete_account_task(account_id)
# Assert
mock_deps["billing"].delete_account.assert_called_once_with(account_id)
mock_deps["mail_task"].delay.assert_called_once_with(account.email)
def test_billing_disabled_account_exists_sends_email_only(self, mock_db_session, mock_deps):
# Arrange
account_id = "acc-456"
account = _set_account_found(mock_db_session, email="x@y.com")
# Disable billing
with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", False):
# Act
delete_account_task(account_id)
# Assert
mock_deps["billing"].delete_account.assert_not_called()
mock_deps["mail_task"].delay.assert_called_once_with(account.email)
def test_account_not_found_billing_enabled_calls_billing_no_email(self, mock_db_session, mock_deps, caplog):
# Arrange
account_id = "missing-id"
_set_account_missing(mock_db_session)
# Enable billing
with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True):
# Act
delete_account_task(account_id)
# Assert
mock_deps["billing"].delete_account.assert_called_once_with(account_id)
mock_deps["mail_task"].delay.assert_not_called()
# Optional: verify log contains not found message
assert any("not found" in rec.getMessage().lower() for rec in caplog.records)
def test_billing_delete_raises_propagates_and_no_email(self, mock_db_session, mock_deps):
# Arrange
account_id = "acc-err"
_set_account_found(mock_db_session, email="err@ex.com")
mock_deps["billing"].delete_account.side_effect = RuntimeError("billing down")
# Enable billing
with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True):
# Act & Assert
with pytest.raises(RuntimeError):
delete_account_task(account_id)
# Ensure email was not sent
mock_deps["mail_task"].delay.assert_not_called()

4
api/uv.lock generated
View File

@ -1281,7 +1281,7 @@ wheels = [
[[package]]
name = "dify-agent"
version = "0.1.0"
source = { editable = "../dify-agent" }
source = { directory = "../dify-agent" }
dependencies = [
{ name = "httpx" },
{ name = "pydantic" },
@ -1615,7 +1615,7 @@ requires-dist = [
{ name = "boto3", specifier = ">=1.43.14,<2.0.0" },
{ name = "celery", specifier = ">=5.6.3,<6.0.0" },
{ name = "croniter", specifier = ">=6.2.2,<7.0.0" },
{ name = "dify-agent", editable = "../dify-agent" },
{ name = "dify-agent", directory = "../dify-agent" },
{ name = "fastopenapi", extras = ["flask"], specifier = "==0.7.0" },
{ name = "flask", specifier = ">=3.1.3,<4.0.0" },
{ name = "flask-compress", specifier = ">=1.24,<2.0.0" },

View File

@ -147,11 +147,6 @@
"count": 1
}
},
"web/app/(commonLayout)/snippets/[snippetId]/page.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/(humanInputLayout)/form/[token]/form.tsx": {
"react/set-state-in-effect": {
"count": 1
@ -248,11 +243,6 @@
"count": 1
}
},
"web/app/components/app-sidebar/nav-link/index.tsx": {
"tailwindcss/enforce-consistent-class-order": {
"count": 3
}
},
"web/app/components/app/annotation/add-annotation-modal/edit-item/index.tsx": {
"erasable-syntax-only/enums": {
"count": 1
@ -3172,16 +3162,6 @@
"count": 2
}
},
"web/app/components/snippets/hooks/use-nodes-sync-draft.ts": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/snippets/hooks/use-snippet-run.ts": {
"no-restricted-imports": {
"count": 2
}
},
"web/app/components/tools/edit-custom-collection-modal/get-schema.tsx": {
"no-restricted-imports": {
"count": 1
@ -3352,11 +3332,6 @@
"count": 1
}
},
"web/app/components/workflow/block-selector/blocks.tsx": {
"unused-imports/no-unused-imports": {
"count": 1
}
},
"web/app/components/workflow/block-selector/hooks.ts": {
"react/set-state-in-effect": {
"count": 1
@ -5240,11 +5215,6 @@
"count": 1
}
},
"web/service/__tests__/use-snippet-workflows.spec.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/service/access-control.ts": {
"@tanstack/query/exhaustive-deps": {
"count": 1
@ -5510,11 +5480,6 @@
"count": 3
}
},
"web/service/use-snippet-workflows.ts": {
"no-restricted-imports": {
"count": 1
}
},
"web/service/use-tools.ts": {
"no-restricted-imports": {
"count": 1

View File

@ -32,7 +32,6 @@ import { Dialog, DialogContent, DialogTrigger } from '@langgenius/dify-ui/dialog
import { Drawer, DrawerPopup, DrawerTrigger } from '@langgenius/dify-ui/drawer'
import { FieldControl, FieldLabel, FieldRoot } from '@langgenius/dify-ui/field'
import { Form } from '@langgenius/dify-ui/form'
import { Kbd, KbdGroup } from '@langgenius/dify-ui/kbd'
import { Popover, PopoverContent, PopoverTrigger } from '@langgenius/dify-ui/popover'
import { SegmentedControl, SegmentedControlItem } from '@langgenius/dify-ui/segmented-control'
import { Textarea } from '@langgenius/dify-ui/textarea'
@ -47,7 +46,6 @@ Importing from `@langgenius/dify-ui` (no subpath) is intentionally not supported
| ---------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------ |
| Actions | `./button` | Design-system CTA primitive with `cva` variants. |
| Controls | `./segmented-control` | SegmentedControl for mode, filter, and view selection. |
| Display | `./kbd` | Keyboard input and shortcut keycap primitives. |
| Feedback | `./meter`, `./toast` | Meter is inline status; Toast owns the `z-60` layer. |
| Form | `./form`, `./field`, `./fieldset`, `./input`, `./textarea`, `./checkbox`, `./checkbox-group`, `./radio`, `./radio-group`, `./number-field`, `./select`, `./slider`, `./switch` | Native form boundary, field semantics, and controls. |
| Layout | `./scroll-area` | Custom-styled scrollbar over the host viewport. |

View File

@ -69,10 +69,6 @@
"types": "./src/input/index.tsx",
"import": "./src/input/index.tsx"
},
"./kbd": {
"types": "./src/kbd/index.tsx",
"import": "./src/kbd/index.tsx"
},
"./meter": {
"types": "./src/meter/index.tsx",
"import": "./src/meter/index.tsx"
@ -175,7 +171,6 @@
"@storybook/addon-themes": "catalog:",
"@storybook/react-vite": "catalog:",
"@tailwindcss/vite": "catalog:",
"@tanstack/react-hotkeys": "catalog:",
"@tanstack/react-virtual": "catalog:",
"@types/react": "catalog:",
"@types/react-dom": "catalog:",

View File

@ -23,7 +23,6 @@ import {
useAutocompleteFilteredItems,
} from '.'
import { cn } from '../cn'
import { Kbd } from '../kbd'
type Suggestion = {
value: string
@ -310,9 +309,9 @@ const CommandPaletteList = () => {
<span className="block truncate system-xs-regular text-text-tertiary">{item.description}</span>
</span>
</span>
<Kbd className="text-text-quaternary">
<kbd className="rounded-md border border-divider-subtle bg-components-badge-bg-dimm px-1.5 py-0.5 text-text-quaternary system-2xs-medium">
Enter
</Kbd>
</kbd>
</AutocompleteItem>
)}
</AutocompleteCollection>

View File

@ -1,59 +0,0 @@
import { render } from 'vitest-browser-react'
import { Kbd, KbdGroup } from '../index'
describe('Kbd', () => {
it('renders a native kbd element with the default gray variant', async () => {
const screen = await render(<Kbd></Kbd>)
const key = screen.getByText('⌘').element()
expect(key.tagName).toBe('KBD')
await expect.element(screen.getByText('⌘')).toHaveClass(
'h-4',
'min-w-4',
'px-px',
'rounded-sm',
'bg-components-kbd-bg-gray',
'text-text-tertiary',
'system-kbd',
)
})
it('applies the white variant for elevated or inverse surfaces', async () => {
const screen = await render(<Kbd color="white"></Kbd>)
await expect.element(screen.getByText('↵')).toHaveClass(
'bg-components-kbd-bg-white',
'text-text-primary-on-surface',
)
})
it('marks disabled keycaps visually without adding widget semantics', async () => {
const screen = await render(<Kbd disabled></Kbd>)
await expect.element(screen.getByText('⌘')).toHaveAttribute('data-disabled')
await expect.element(screen.getByText('⌘')).toHaveClass('opacity-30')
await expect.element(screen.getByText('⌘')).not.toHaveAttribute('aria-disabled')
})
it('merges custom classes with the design-system recipe', async () => {
const screen = await render(<Kbd className="custom-key h-5">K</Kbd>)
await expect.element(screen.getByText('K')).toHaveClass('custom-key', 'h-5')
})
})
describe('KbdGroup', () => {
it('groups keycaps without replacing individual kbd semantics', async () => {
const screen = await render(
<KbdGroup aria-label="Command Shift K">
<Kbd></Kbd>
<Kbd></Kbd>
<Kbd>K</Kbd>
</KbdGroup>,
)
const group = screen.getByLabelText('Command Shift K').element()
expect(group.tagName).toBe('SPAN')
expect(group.querySelectorAll('kbd')).toHaveLength(3)
})
})

View File

@ -1,230 +0,0 @@
import type { Meta, StoryObj } from '@storybook/react-vite'
import type { FormatDisplayOptions, RegisterableHotkey } from '@tanstack/react-hotkeys'
import { formatForDisplay } from '@tanstack/react-hotkeys'
import { Kbd, KbdGroup } from '.'
import {
ContextMenu,
ContextMenuContent,
ContextMenuItem,
ContextMenuSeparator,
ContextMenuTrigger,
} from '../context-menu'
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from '../tooltip'
const meta = {
title: 'Base/UI/Kbd',
component: Kbd,
parameters: {
layout: 'centered',
docs: {
description: {
component:
'Keyboard input primitives aligned with the Dify Key Set design. '
+ '`Kbd` renders a native `<kbd>` element for a single key or key-like token. '
+ '`KbdGroup` only groups multiple keycaps; it does not replace the individual `<kbd>` semantics.',
},
},
},
tags: ['autodocs'],
argTypes: {
color: {
control: 'select',
options: ['gray', 'white'],
},
disabled: { control: 'boolean' },
},
args: {
children: 'K',
color: 'gray',
},
} satisfies Meta<typeof Kbd>
export default meta
type Story = StoryObj<typeof meta>
const displayKeys = (
hotkey: RegisterableHotkey | (string & {}),
platform: FormatDisplayOptions['platform'] = 'mac',
) => {
if (typeof hotkey !== 'string')
return [formatForDisplay(hotkey, { platform })]
return hotkey
.split('+')
.filter(Boolean)
.map(key => formatForDisplay(key, { platform }))
}
const HotkeyKbdGroup = ({
hotkey,
color = 'gray',
platform = 'mac',
}: {
hotkey: RegisterableHotkey | (string & {})
color?: 'gray' | 'white'
platform?: FormatDisplayOptions['platform']
}) => (
<KbdGroup>
{displayKeys(hotkey, platform).map((key, index) => (
<Kbd key={`${key}-${index}`} color={color}>
{key}
</Kbd>
))}
</KbdGroup>
)
export const Default: Story = {
render: () => <HotkeyKbdGroup hotkey="Mod+K" />,
}
export const KeySet: Story = {
parameters: {
docs: {
description: {
story: 'Figma Key Set variants: gray and white, each with a disabled state. Disabled is visual only because `<kbd>` is not an interactive widget.',
},
},
},
render: () => (
<div className="grid grid-cols-[auto_auto_auto] items-center gap-x-4 gap-y-3 rounded-xl bg-components-panel-bg p-5">
<span className="system-xs-medium text-text-tertiary">Gray</span>
<KbdGroup>
<Kbd></Kbd>
<Kbd></Kbd>
</KbdGroup>
<KbdGroup>
<Kbd disabled></Kbd>
<Kbd disabled></Kbd>
</KbdGroup>
<span className="system-xs-medium text-text-tertiary">White</span>
<div className="rounded-lg bg-gray-900 p-2">
<KbdGroup>
<Kbd color="white"></Kbd>
<Kbd color="white"></Kbd>
</KbdGroup>
</div>
<div className="rounded-lg bg-gray-900 p-2">
<KbdGroup>
<Kbd color="white" disabled></Kbd>
<Kbd color="white" disabled></Kbd>
</KbdGroup>
</div>
</div>
),
}
export const FormattedShortcuts: Story = {
parameters: {
docs: {
description: {
story: '`Kbd` does not parse hotkeys. Compose it with a formatter at the feature layer; this story uses TanStack Hotkeys `formatForDisplay` for platform-aware labels.',
},
},
},
render: () => (
<div className="grid grid-cols-[auto_auto_auto] items-center gap-x-5 gap-y-3 rounded-xl bg-components-panel-bg p-5">
<span className="system-xs-medium text-text-tertiary">Action</span>
<span className="system-xs-medium text-text-tertiary">macOS</span>
<span className="system-xs-medium text-text-tertiary">Windows</span>
<span className="system-sm-regular text-text-secondary">Search</span>
<HotkeyKbdGroup hotkey="Mod+K" platform="mac" />
<HotkeyKbdGroup hotkey="Mod+K" platform="windows" />
<span className="system-sm-regular text-text-secondary">Save</span>
<HotkeyKbdGroup hotkey="Mod+S" platform="mac" />
<HotkeyKbdGroup hotkey="Mod+S" platform="windows" />
<span className="system-sm-regular text-text-secondary">Redo</span>
<HotkeyKbdGroup hotkey="Mod+Shift+Z" platform="mac" />
<HotkeyKbdGroup hotkey="Mod+Shift+Z" platform="windows" />
</div>
),
}
export const InTooltip: Story = {
decorators: [
Story => (
<TooltipProvider delay={0}>
<Story />
</TooltipProvider>
),
],
parameters: {
docs: {
description: {
story: 'Shortcut keycaps can be composed inside short tooltip content. The trigger keeps its own accessible name; the tooltip is only a visual hint.',
},
},
},
render: () => (
<Tooltip open>
<TooltipTrigger
render={(
<button
type="button"
aria-label="Collapse sidebar"
className="inline-flex size-8 items-center justify-center rounded-lg border border-divider-subtle bg-components-button-secondary-bg text-text-secondary shadow-xs"
>
<span aria-hidden className="i-ri-sidebar-fold-line size-4" />
</button>
)}
/>
<TooltipContent className="flex items-center gap-1">
<span>Collapse sidebar</span>
<HotkeyKbdGroup hotkey="Mod+B" />
</TooltipContent>
</Tooltip>
),
}
const MENU_ITEMS = [
{ label: 'Copy', icon: 'i-ri-file-copy-line', hotkey: 'Mod+C' },
{ label: 'Duplicate', icon: 'i-ri-stack-line', hotkey: 'Mod+D' },
{ label: 'Paste', icon: 'i-ri-clipboard-line', hotkey: 'Mod+V' },
] as const
export const InContextMenu: Story = {
parameters: {
docs: {
description: {
story: 'A compact context-menu composition based on the Dify Design Kit context menu example. The menu is intentionally small here because the story focuses on shortcut keycaps.',
},
},
},
render: () => (
<ContextMenu>
<ContextMenuTrigger
render={(
<button
type="button"
className="flex h-28 w-60 items-center justify-center rounded-xl border border-divider-subtle bg-background-default-subtle px-6 text-center system-sm-regular text-text-tertiary"
/>
)}
>
Context menu trigger
</ContextMenuTrigger>
<ContextMenuContent popupClassName="w-60">
{MENU_ITEMS.map(({ label, icon, hotkey }) => (
<ContextMenuItem key={label} className="justify-between gap-4">
<span aria-hidden className={`${icon} size-4 shrink-0 text-text-tertiary`} />
<span className="min-w-0 flex-1 truncate">{label}</span>
<HotkeyKbdGroup hotkey={hotkey} />
</ContextMenuItem>
))}
<ContextMenuSeparator />
<ContextMenuItem variant="destructive" className="justify-between gap-4">
<span aria-hidden className="i-ri-delete-bin-line size-4 shrink-0" />
<span className="min-w-0 flex-1 truncate">Delete</span>
<HotkeyKbdGroup hotkey="Delete" />
</ContextMenuItem>
</ContextMenuContent>
</ContextMenu>
),
}

Some files were not shown because too many files have changed in this diff Show More