mirror of
https://github.com/langgenius/dify.git
synced 2026-05-29 05:07:55 +08:00
Compare commits
44 Commits
dependabot
...
codex/migr
| Author | SHA1 | Date | |
|---|---|---|---|
| 49b638a099 | |||
| 99d9a6f6a2 | |||
| 83f6e7daf9 | |||
| e8de10a3b5 | |||
| f5ab5e7eb3 | |||
| 0c40e1c2a0 | |||
| c29d76757e | |||
| 91c1d3ad81 | |||
| 57b02e341c | |||
| b94ff65e9f | |||
| 678260e34e | |||
| 739e34d08a | |||
| 825fb9cb89 | |||
| 0e1f19a380 | |||
| 332d1ea533 | |||
| 9cdeffd0b1 | |||
| 09ef785a20 | |||
| d2788d7aba | |||
| cee90a4e82 | |||
| b2710b875b | |||
| 6464255d33 | |||
| 50face5760 | |||
| b034449a0c | |||
| a8d380bcaf | |||
| bee21c9f86 | |||
| cab215e209 | |||
| 7ae4ca9a60 | |||
| d342ff1a1e | |||
| 4384d8910e | |||
| fc773b9f57 | |||
| 6e1e0d9439 | |||
| 5c5a6e83e5 | |||
| dade318f00 | |||
| ebff9a3639 | |||
| 58b8fc21d4 | |||
| e0ad088657 | |||
| 323b2b82e0 | |||
| 7d45335a32 | |||
| f5d664887b | |||
| 5aa24c25d9 | |||
| eed8d659d1 | |||
| 59e99ee1ae | |||
| 533929d314 | |||
| fb07b43107 |
1
.github/CODEOWNERS
vendored
1
.github/CODEOWNERS
vendored
@ -166,6 +166,7 @@
|
||||
|
||||
# Frontend - App - API Documentation
|
||||
/web/app/components/develop/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/develop/template/*.mdx @JzoNgKVO @iamjoel @RiskeyL
|
||||
|
||||
# Frontend - App - Logs and Annotations
|
||||
/web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
|
||||
|
||||
10
.github/workflows/cli-tests.yml
vendored
10
.github/workflows/cli-tests.yml
vendored
@ -15,8 +15,12 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: CLI Tests
|
||||
runs-on: depot-ubuntu-24.04
|
||||
name: CLI Tests (${{ matrix.os }})
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [depot-ubuntu-24.04, windows-latest, macos-latest]
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
defaults:
|
||||
@ -37,7 +41,7 @@ jobs:
|
||||
run: pnpm ci
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
if: ${{ env.CODECOV_TOKEN != '' && matrix.os == 'depot-ubuntu-24.04' }}
|
||||
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
|
||||
with:
|
||||
directory: cli/coverage
|
||||
|
||||
@ -27,7 +27,7 @@ COPY api/providers ./providers
|
||||
COPY dify-agent/pyproject.toml dify-agent/README.md /app/dify-agent/
|
||||
COPY dify-agent/src /app/dify-agent/src
|
||||
# Trust the checked-in lock during image builds; local path sources are copied from the repository context.
|
||||
RUN uv sync --frozen --no-dev
|
||||
RUN uv sync --frozen --no-dev --no-editable
|
||||
|
||||
# production stage
|
||||
FROM base AS production
|
||||
|
||||
@ -223,10 +223,11 @@ def initialize_extensions(app: DifyApp):
|
||||
|
||||
def create_migrations_app() -> DifyApp:
|
||||
app = create_flask_app_with_configs()
|
||||
from extensions import ext_database, ext_migrate
|
||||
from extensions import ext_commands, ext_database, ext_migrate
|
||||
|
||||
# Initialize only required extensions
|
||||
ext_database.init_app(app)
|
||||
ext_migrate.init_app(app)
|
||||
ext_commands.init_app(app)
|
||||
|
||||
return app
|
||||
|
||||
@ -31,18 +31,22 @@ from clients.agent_backend.fake_client import FakeAgentBackendRunClient, FakeAge
|
||||
from clients.agent_backend.request_builder import (
|
||||
AGENT_SOUL_PROMPT_LAYER_ID,
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_ID,
|
||||
DIFY_PLUGIN_TOOLS_LAYER_ID,
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID,
|
||||
AgentBackendModelConfig,
|
||||
AgentBackendOutputConfig,
|
||||
AgentBackendRunRequestBuilder,
|
||||
AgentBackendWorkflowNodeRunInput,
|
||||
CleanupLayerSpec,
|
||||
extract_cleanup_layer_specs,
|
||||
redact_for_agent_backend_log,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AGENT_SOUL_PROMPT_LAYER_ID",
|
||||
"DIFY_EXECUTION_CONTEXT_LAYER_ID",
|
||||
"DIFY_PLUGIN_TOOLS_LAYER_ID",
|
||||
"WORKFLOW_NODE_JOB_PROMPT_LAYER_ID",
|
||||
"WORKFLOW_USER_PROMPT_LAYER_ID",
|
||||
"AgentBackendError",
|
||||
@ -66,9 +70,11 @@ __all__ = [
|
||||
"AgentBackendTransportError",
|
||||
"AgentBackendValidationError",
|
||||
"AgentBackendWorkflowNodeRunInput",
|
||||
"CleanupLayerSpec",
|
||||
"DifyAgentBackendRunClient",
|
||||
"FakeAgentBackendRunClient",
|
||||
"FakeAgentBackendScenario",
|
||||
"create_agent_backend_run_client",
|
||||
"extract_cleanup_layer_specs",
|
||||
"redact_for_agent_backend_log",
|
||||
]
|
||||
|
||||
@ -20,6 +20,8 @@ from dify_agent.protocol import (
|
||||
RunEvent,
|
||||
RunFailedEvent,
|
||||
RunFailedEventData,
|
||||
RunPausedEvent,
|
||||
RunPausedEventData,
|
||||
RunStartedEvent,
|
||||
RunStatusResponse,
|
||||
RunSucceededEvent,
|
||||
@ -34,6 +36,7 @@ class FakeAgentBackendScenario(StrEnum):
|
||||
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
PAUSED = "paused"
|
||||
|
||||
|
||||
class FakeAgentBackendRunClient:
|
||||
@ -89,6 +92,13 @@ class FakeAgentBackendRunClient:
|
||||
updated_at=_FIXED_TIME,
|
||||
error="fake failure",
|
||||
)
|
||||
case FakeAgentBackendScenario.PAUSED:
|
||||
return RunStatusResponse(
|
||||
run_id=run_id,
|
||||
status="paused",
|
||||
created_at=_FIXED_TIME,
|
||||
updated_at=_FIXED_TIME,
|
||||
)
|
||||
|
||||
def _events(self, run_id: str) -> tuple[RunEvent, ...]:
|
||||
match self.scenario:
|
||||
@ -115,3 +125,17 @@ class FakeAgentBackendRunClient:
|
||||
data=RunFailedEventData(error="fake failure", reason="unit_test"),
|
||||
),
|
||||
)
|
||||
case FakeAgentBackendScenario.PAUSED:
|
||||
return (
|
||||
RunStartedEvent(id="1-0", run_id=run_id, created_at=_FIXED_TIME),
|
||||
RunPausedEvent(
|
||||
id="2-0",
|
||||
run_id=run_id,
|
||||
created_at=_FIXED_TIME,
|
||||
data=RunPausedEventData(
|
||||
reason="human_input_required",
|
||||
message="Agent requested human input.",
|
||||
session_snapshot=CompositorSessionSnapshot(layers=[]),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@ -11,15 +11,19 @@ composition-driven.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar
|
||||
from typing import ClassVar, cast
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from agenton.compositor.schemas import LayerSessionSnapshot
|
||||
from agenton.layers import ExitIntent
|
||||
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig
|
||||
from agenton_collections.layers.pydantic_ai import PYDANTIC_AI_HISTORY_LAYER_TYPE_ID
|
||||
from dify_agent.layers.dify_plugin import (
|
||||
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
|
||||
DifyPluginCredentialValue,
|
||||
DifyPluginLLMLayerConfig,
|
||||
DifyPluginToolsLayerConfig,
|
||||
)
|
||||
from dify_agent.layers.execution_context import (
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
|
||||
@ -27,6 +31,7 @@ from dify_agent.layers.execution_context import (
|
||||
)
|
||||
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
|
||||
from dify_agent.protocol import (
|
||||
DIFY_AGENT_HISTORY_LAYER_ID,
|
||||
DIFY_AGENT_MODEL_LAYER_ID,
|
||||
DIFY_AGENT_OUTPUT_LAYER_ID,
|
||||
CreateRunRequest,
|
||||
@ -41,6 +46,85 @@ AGENT_SOUL_PROMPT_LAYER_ID = "agent_soul_prompt"
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID = "workflow_node_job_prompt"
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID = "workflow_user_prompt"
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_ID = "execution_context"
|
||||
DIFY_PLUGIN_TOOLS_LAYER_ID = "tools"
|
||||
|
||||
# Layer types that hold credentials in their per-run config. These are excluded
|
||||
# from the cleanup-replay composition (and from the snapshot that is sent with
|
||||
# the cleanup request) because we deliberately do not persist plaintext
|
||||
# credentials between runs.
|
||||
_CLEANUP_EXCLUDED_LAYER_TYPES: tuple[str, ...] = (
|
||||
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
|
||||
)
|
||||
|
||||
|
||||
class CleanupLayerSpec(BaseModel):
|
||||
"""One layer node replayed by an Agent backend cleanup-only run.
|
||||
|
||||
Cleanup composition cannot include credential-bearing plugin layers, so we
|
||||
persist only the non-plugin layer specs together with the original config.
|
||||
Storing the config (rather than just ``name``/``type``) means cleanup does
|
||||
not depend on the original build-time inputs being re-derivable.
|
||||
"""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
deps: dict[str, str] = Field(default_factory=dict)
|
||||
metadata: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
config: JsonValue = None
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
def extract_cleanup_layer_specs(composition: RunComposition) -> list[CleanupLayerSpec]:
|
||||
"""Project the in-flight composition into the persistable cleanup spec list.
|
||||
|
||||
Plugin layers are intentionally dropped (their configs hold credentials and
|
||||
the lifecycle contract says "do not include an LLM layer" during cleanup).
|
||||
The filtered names must later drive snapshot filtering so the agenton
|
||||
compositor's name-order check still passes for the cleanup run.
|
||||
"""
|
||||
excluded = set(_CLEANUP_EXCLUDED_LAYER_TYPES)
|
||||
specs: list[CleanupLayerSpec] = []
|
||||
for layer in composition.layers:
|
||||
if layer.type in excluded:
|
||||
continue
|
||||
config_value: JsonValue = None
|
||||
if isinstance(layer.config, BaseModel):
|
||||
config_value = layer.config.model_dump(mode="json", warnings=False)
|
||||
else:
|
||||
# ``RunLayerSpec.config`` is typed as ``LayerConfigInput`` which
|
||||
# includes ``Mapping[str, object] | bytes``. In the cleanup-replay
|
||||
# pipeline our builder only emits BaseModel-derived configs or
|
||||
# ``None``, so the wider input alias narrows safely here.
|
||||
config_value = cast(JsonValue, layer.config)
|
||||
specs.append(
|
||||
CleanupLayerSpec(
|
||||
name=layer.name,
|
||||
type=layer.type,
|
||||
deps=dict(layer.deps),
|
||||
metadata=dict(layer.metadata),
|
||||
config=config_value,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
|
||||
|
||||
def _filter_snapshot_to_specs(
|
||||
snapshot: CompositorSessionSnapshot,
|
||||
specs: list[CleanupLayerSpec],
|
||||
) -> CompositorSessionSnapshot:
|
||||
"""Keep only snapshot layers whose names appear in the cleanup spec list.
|
||||
|
||||
The agenton compositor rejects a snapshot whose layer-name sequence does
|
||||
not match the active composition exactly. Cleanup-replay drops plugin
|
||||
layers, so we must drop the matching snapshot entries here.
|
||||
"""
|
||||
kept_names = {spec.name for spec in specs}
|
||||
filtered_layers: list[LayerSessionSnapshot] = [layer for layer in snapshot.layers if layer.name in kept_names]
|
||||
if len(filtered_layers) == len(snapshot.layers):
|
||||
return snapshot
|
||||
return CompositorSessionSnapshot(schema_version=snapshot.schema_version, layers=filtered_layers)
|
||||
|
||||
|
||||
class AgentBackendModelConfig(BaseModel):
|
||||
@ -81,8 +165,10 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
|
||||
purpose: RunPurpose = "workflow_node"
|
||||
idempotency_key: str | None = None
|
||||
output: AgentBackendOutputConfig | None = None
|
||||
tools: DifyPluginToolsLayerConfig | None = None
|
||||
session_snapshot: CompositorSessionSnapshot | None = None
|
||||
suspend_on_exit: bool = False
|
||||
include_history: bool = True
|
||||
suspend_on_exit: bool = True
|
||||
metadata: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
@ -98,6 +184,50 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
|
||||
class AgentBackendRunRequestBuilder:
|
||||
"""Converts API product state into the public ``dify-agent`` run protocol."""
|
||||
|
||||
def build_cleanup_request(
|
||||
self,
|
||||
*,
|
||||
session_snapshot: CompositorSessionSnapshot,
|
||||
composition_layer_specs: list[CleanupLayerSpec],
|
||||
idempotency_key: str | None = None,
|
||||
metadata: dict[str, JsonValue] | None = None,
|
||||
) -> CreateRunRequest:
|
||||
"""Build a lifecycle-only cleanup request that replays the prior layers.
|
||||
|
||||
The agenton compositor enforces that the session snapshot's layer names
|
||||
match the active composition in order, so cleanup must replay the same
|
||||
non-plugin layer graph that produced the snapshot. Plugin layers
|
||||
(``dify.plugin.llm``, ``dify.plugin.tools``) are excluded from both the
|
||||
composition and the snapshot before submission because their configs
|
||||
require credentials that are not persisted between runs.
|
||||
"""
|
||||
if not composition_layer_specs:
|
||||
raise ValueError(
|
||||
"build_cleanup_request requires composition_layer_specs; an empty "
|
||||
"composition would fail the agent backend's snapshot validation."
|
||||
)
|
||||
request_metadata = dict(metadata or {})
|
||||
request_metadata["agent_backend_lifecycle"] = "session_cleanup"
|
||||
layers = [
|
||||
RunLayerSpec(
|
||||
name=spec.name,
|
||||
type=spec.type,
|
||||
deps=dict(spec.deps),
|
||||
metadata=dict(spec.metadata),
|
||||
config=spec.config,
|
||||
)
|
||||
for spec in composition_layer_specs
|
||||
]
|
||||
filtered_snapshot = _filter_snapshot_to_specs(session_snapshot, composition_layer_specs)
|
||||
return CreateRunRequest(
|
||||
composition=RunComposition(layers=layers),
|
||||
purpose="workflow_node",
|
||||
idempotency_key=idempotency_key,
|
||||
metadata=request_metadata,
|
||||
session_snapshot=filtered_snapshot,
|
||||
on_exit=LayerExitSignals(default=ExitIntent.DELETE),
|
||||
)
|
||||
|
||||
def build_for_workflow_node(self, run_input: AgentBackendWorkflowNodeRunInput) -> CreateRunRequest:
|
||||
"""Build a workflow Agent Node run request without defining another wire schema."""
|
||||
layers: list[RunLayerSpec] = []
|
||||
@ -131,6 +261,20 @@ class AgentBackendRunRequestBuilder:
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.execution_context,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if run_input.include_history:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_HISTORY_LAYER_ID,
|
||||
type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID,
|
||||
metadata={**run_input.metadata, "origin": "agent_session_history"},
|
||||
)
|
||||
)
|
||||
|
||||
layers.extend(
|
||||
[
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_MODEL_LAYER_ID,
|
||||
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
@ -147,6 +291,17 @@ class AgentBackendRunRequestBuilder:
|
||||
]
|
||||
)
|
||||
|
||||
if run_input.tools is not None and run_input.tools.tools:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_PLUGIN_TOOLS_LAYER_ID,
|
||||
type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
|
||||
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.tools,
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.output is not None:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
|
||||
@ -3,6 +3,13 @@ CLI command modules extracted from `commands.py`.
|
||||
"""
|
||||
|
||||
from .account import create_tenant, reset_email, reset_password
|
||||
from .data_migrate import data_migrate, legacy_model_types
|
||||
from .data_migration import (
|
||||
export_migration_data,
|
||||
export_migration_data_template,
|
||||
import_migration_data,
|
||||
migration_data_wizard,
|
||||
)
|
||||
from .plugin import (
|
||||
extract_plugins,
|
||||
extract_unique_plugins,
|
||||
@ -25,7 +32,12 @@ from .retention import (
|
||||
restore_workflow_runs,
|
||||
)
|
||||
from .storage import clear_orphaned_file_records, file_usage, migrate_oss, remove_orphaned_files_on_storage
|
||||
from .system import convert_to_agent_apps, fix_app_site_missing, reset_encrypt_key_pair, upgrade_db
|
||||
from .system import (
|
||||
convert_to_agent_apps,
|
||||
fix_app_site_missing,
|
||||
reset_encrypt_key_pair,
|
||||
upgrade_db,
|
||||
)
|
||||
from .vector import (
|
||||
add_qdrant_index,
|
||||
migrate_annotation_vector_database,
|
||||
@ -44,18 +56,24 @@ __all__ = [
|
||||
"clear_orphaned_file_records",
|
||||
"convert_to_agent_apps",
|
||||
"create_tenant",
|
||||
"data_migrate",
|
||||
"delete_archived_workflow_runs",
|
||||
"export_app_messages",
|
||||
"export_migration_data",
|
||||
"export_migration_data_template",
|
||||
"extract_plugins",
|
||||
"extract_unique_plugins",
|
||||
"file_usage",
|
||||
"fix_app_site_missing",
|
||||
"import_migration_data",
|
||||
"install_plugins",
|
||||
"install_rag_pipeline_plugins",
|
||||
"legacy_model_types",
|
||||
"migrate_annotation_vector_database",
|
||||
"migrate_data_for_plugin",
|
||||
"migrate_knowledge_vector_database",
|
||||
"migrate_oss",
|
||||
"migration_data_wizard",
|
||||
"old_metadata_migration",
|
||||
"remove_orphaned_files_on_storage",
|
||||
"reset_email",
|
||||
|
||||
179
api/commands/data_migrate.py
Normal file
179
api/commands/data_migrate.py
Normal file
@ -0,0 +1,179 @@
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import click
|
||||
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from services.legacy_model_type_migration import (
|
||||
VALID_TABLE_NAMES,
|
||||
LegacyModelTypeMigrationService,
|
||||
load_tenant_ids_from_file,
|
||||
)
|
||||
|
||||
_SUPPORTED_MODEL_TYPE_CHOICES = (
|
||||
ModelType.LLM.value,
|
||||
ModelType.TEXT_EMBEDDING.value,
|
||||
ModelType.RERANK.value,
|
||||
)
|
||||
_DEFAULT_CONCURRENCY = os.cpu_count() or 1
|
||||
|
||||
|
||||
def _normalize_multi_value_option(
|
||||
values: tuple[str, ...],
|
||||
*,
|
||||
valid_values: tuple[str, ...],
|
||||
option_name: str,
|
||||
) -> tuple[str, ...]:
|
||||
normalized_values: list[str] = []
|
||||
seen_values: set[str] = set()
|
||||
|
||||
for value in values:
|
||||
for item in value.split(","):
|
||||
normalized_item = item.strip()
|
||||
if not normalized_item:
|
||||
continue
|
||||
if normalized_item not in valid_values:
|
||||
raise click.BadParameter(
|
||||
f"invalid value '{normalized_item}'. valid values: {', '.join(valid_values)}",
|
||||
param_hint=option_name,
|
||||
)
|
||||
if normalized_item in seen_values:
|
||||
continue
|
||||
seen_values.add(normalized_item)
|
||||
normalized_values.append(normalized_item)
|
||||
|
||||
return tuple(normalized_values)
|
||||
|
||||
|
||||
@click.group(
|
||||
"data-migrate",
|
||||
help="Online data migration commands.",
|
||||
)
|
||||
def data_migrate() -> None:
|
||||
"""Namespace for production data migration commands."""
|
||||
|
||||
|
||||
@click.command(
|
||||
"legacy-model-types",
|
||||
help=(
|
||||
"Migrate legacy provider model_type values to canonical values. "
|
||||
"Default is dry-run and emits JSON lines only. "
|
||||
"If --tables includes provider_model_credentials, the command may also update "
|
||||
"provider_models and load_balancing_model_configs references so merged credentials stay reachable."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--apply",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Apply the migration. Default is dry-run.",
|
||||
)
|
||||
@click.option(
|
||||
"--tables",
|
||||
"tables",
|
||||
multiple=True,
|
||||
type=str,
|
||||
help=(
|
||||
"Limit migration to specific tables. Accepts comma-separated values or repeated flags.\n"
|
||||
"\n"
|
||||
"Options: load_balancing_model_configs, provider_model_credentials, "
|
||||
"provider_model_settings, provider_models, tenant_default_models.\n\n"
|
||||
"When provider_model_credentials is selected, provider_models and "
|
||||
"load_balancing_model_configs may also be updated for credential reference rewrites.\n"
|
||||
"\n"
|
||||
"If unspecified, all relevant tables are migrated."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--model-types",
|
||||
"model_types",
|
||||
multiple=True,
|
||||
type=str,
|
||||
help=(
|
||||
"Canonical model types to migrate. Accepts comma-separated values or repeated flags.\n"
|
||||
"\n"
|
||||
"Options: llm,text-embedding,rerank\n"
|
||||
"\n"
|
||||
"If unspecified, all relevant legacy model types are migrated."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--tenant-id-file",
|
||||
type=click.Path(exists=True, dir_okay=False, readable=True, resolve_path=True),
|
||||
help="Optional file containing tenant ids, one per line.",
|
||||
)
|
||||
@click.option(
|
||||
"--output",
|
||||
type=click.Path(dir_okay=False, resolve_path=True, path_type=Path),
|
||||
help=(
|
||||
"Optional file path for JSON lines event logs. Defaults to stdout.\n"
|
||||
"It's highly recommended to save the event logs to a file and preserve it for a period of time."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--concurrency",
|
||||
type=click.IntRange(min=1),
|
||||
default=_DEFAULT_CONCURRENCY,
|
||||
show_default=True,
|
||||
help="Number of tenant-level worker threads to run in parallel.",
|
||||
)
|
||||
def legacy_model_types(
|
||||
apply: bool,
|
||||
tables: tuple[str, ...],
|
||||
model_types: tuple[str, ...],
|
||||
tenant_id_file: str | None,
|
||||
output: Path | None,
|
||||
concurrency: int = _DEFAULT_CONCURRENCY,
|
||||
) -> None:
|
||||
"""
|
||||
Migrate legacy provider-related model_type values and emit JSON lines events.
|
||||
"""
|
||||
|
||||
normalized_tables = _normalize_multi_value_option(
|
||||
tables,
|
||||
valid_values=VALID_TABLE_NAMES,
|
||||
option_name="--tables",
|
||||
)
|
||||
normalized_model_types = _normalize_multi_value_option(
|
||||
model_types,
|
||||
valid_values=_SUPPORTED_MODEL_TYPE_CHOICES,
|
||||
option_name="--model-types",
|
||||
)
|
||||
selected_model_types = (
|
||||
tuple(ModelType.value_of(model_type) for model_type in normalized_model_types)
|
||||
if normalized_model_types
|
||||
else (
|
||||
ModelType.LLM,
|
||||
ModelType.TEXT_EMBEDDING,
|
||||
ModelType.RERANK,
|
||||
)
|
||||
)
|
||||
tenant_ids = load_tenant_ids_from_file(tenant_id_file) if tenant_id_file else None
|
||||
|
||||
output_context: AbstractContextManager[io.TextIOBase]
|
||||
if output is None:
|
||||
output_context = nullcontext(cast(io.TextIOBase, sys.stdout))
|
||||
else:
|
||||
try:
|
||||
output_context = output.open("w", encoding="utf-8")
|
||||
except OSError as exc:
|
||||
raise click.ClickException(f"failed to open output file '{output}': {exc.strerror or exc}") from exc
|
||||
|
||||
with output_context as output_stream:
|
||||
LegacyModelTypeMigrationService(
|
||||
engine=db.engine,
|
||||
apply=apply,
|
||||
concurrency=concurrency,
|
||||
output=cast(io.TextIOBase, output_stream),
|
||||
tables=normalized_tables or None,
|
||||
model_types=selected_model_types,
|
||||
tenant_ids=tenant_ids,
|
||||
).migrate()
|
||||
|
||||
|
||||
data_migrate.add_command(legacy_model_types)
|
||||
754
api/commands/data_migration.py
Normal file
754
api/commands/data_migration.py
Normal file
@ -0,0 +1,754 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
import yaml
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import Tenant
|
||||
from models.model import App
|
||||
from models.tools import ApiToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.data_migration.dependency_discovery_service import DependencyDiscoveryService
|
||||
from services.data_migration.entities import (
|
||||
DependencyKind,
|
||||
ImportOptions,
|
||||
MigrationDataError,
|
||||
ReportContext,
|
||||
ResourceReportItem,
|
||||
)
|
||||
from services.data_migration.export_service import ExportConfigParser, MigrationExportService
|
||||
from services.data_migration.import_service import ImportRequest, MigrationImportService
|
||||
from services.data_migration.package_service import MigrationPackageService
|
||||
from services.data_migration.report_service import MigrationReportService
|
||||
|
||||
ID_STRATEGY_CHOICES = ["preserve-id", "generate-new-id"]
|
||||
CONFLICT_STRATEGY_CHOICES = ["fail", "skip", "update"]
|
||||
SUPPORTED_WIZARD_APP_MODES = ["workflow", "advanced-chat"]
|
||||
WizardToolMap = dict[str, dict[str, str | None]]
|
||||
WizardToolSelection = dict[str, list[str]]
|
||||
|
||||
|
||||
def _scripted_export_template() -> dict[str, Any]:
|
||||
return {
|
||||
"source_tenant": {
|
||||
"mode": "single",
|
||||
"id": "",
|
||||
"name": "admin's Workspace",
|
||||
},
|
||||
"apps": {
|
||||
"modes": ["workflow", "advanced-chat"],
|
||||
"ids": [],
|
||||
"all": True,
|
||||
},
|
||||
"include_referenced_tools": True,
|
||||
"additional_tools": {
|
||||
"api_tools": [],
|
||||
"workflow_tools": [],
|
||||
"mcp_tools": [],
|
||||
},
|
||||
"include_secrets": False,
|
||||
"import_options": {
|
||||
"create_app_api_token_on_import": False,
|
||||
"id_strategy": "preserve-id",
|
||||
"conflict_strategy": "fail",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@click.command("app-migration-template", help="Print or write a scripted export config JSON template.")
|
||||
@click.option(
|
||||
"--output",
|
||||
"output_file",
|
||||
required=False,
|
||||
type=click.Path(dir_okay=False),
|
||||
help="Path to write the export config JSON template. Prints to stdout when omitted.",
|
||||
)
|
||||
@click.option("--overwrite", is_flag=True, default=False, help="Overwrite output if it already exists.")
|
||||
def export_migration_data_template(output_file: str | None, overwrite: bool) -> None:
|
||||
template_json = json.dumps(_scripted_export_template(), indent=2, ensure_ascii=False) + "\n"
|
||||
if output_file is None:
|
||||
click.echo(template_json, nl=False)
|
||||
return
|
||||
path = Path(output_file)
|
||||
if path.exists() and not overwrite:
|
||||
raise click.ClickException(f"Output file already exists: {output_file}")
|
||||
path.write_text(template_json)
|
||||
click.echo(click.style(f"Output written to {output_file}", fg="green"))
|
||||
|
||||
|
||||
@click.command("export-app-migration", help="Export workflow migration data to a versioned JSON package.")
|
||||
@click.option(
|
||||
"--input",
|
||||
"input_file",
|
||||
required=False,
|
||||
type=click.Path(exists=True, dir_okay=False),
|
||||
help="Path to export config JSON.",
|
||||
)
|
||||
@click.option(
|
||||
"--output",
|
||||
"output_file",
|
||||
required=False,
|
||||
type=click.Path(dir_okay=False),
|
||||
help="Path to migration package JSON.",
|
||||
)
|
||||
@click.option("--overwrite", is_flag=True, default=False, help="Overwrite output if it already exists.")
|
||||
def export_migration_data(input_file: str | None, output_file: str | None, overwrite: bool) -> None:
|
||||
try:
|
||||
_require_options(("--input", input_file), ("--output", output_file))
|
||||
assert input_file is not None
|
||||
assert output_file is not None
|
||||
raw_config = _load_json_object(input_file, "Export config")
|
||||
selection = ExportConfigParser().parse(raw_config)
|
||||
result = MigrationExportService().export(selection)
|
||||
MigrationPackageService().save_package(result.package, output_file, overwrite=overwrite)
|
||||
click.echo(click.style(f"Output written to {output_file}", fg="green"))
|
||||
_render_report(result.report_items, context=_with_output_path(result.report_context, output_file))
|
||||
except MigrationDataError as exc:
|
||||
raise click.ClickException(str(exc)) from exc
|
||||
|
||||
|
||||
@click.command("import-app-migration", help="Import a versioned migration data package.")
|
||||
@click.option(
|
||||
"--input",
|
||||
"input_file",
|
||||
required=False,
|
||||
type=click.Path(exists=True, dir_okay=False),
|
||||
help="Path to migration package JSON.",
|
||||
)
|
||||
@click.option("--target-tenant", default=None, help="Target tenant/workspace name. Overrides package metadata.")
|
||||
@click.option("--operator-email", default=None, help="Operator account email in the target tenant.")
|
||||
@click.option(
|
||||
"--id-strategy",
|
||||
default=None,
|
||||
type=click.Choice(ID_STRATEGY_CHOICES),
|
||||
help="Override package ID strategy.",
|
||||
)
|
||||
@click.option(
|
||||
"--conflict-strategy",
|
||||
default=None,
|
||||
type=click.Choice(CONFLICT_STRATEGY_CHOICES),
|
||||
help="Override package conflict strategy.",
|
||||
)
|
||||
@click.option(
|
||||
"--create-app-api-token-on-import/--no-create-app-api-token-on-import",
|
||||
default=None,
|
||||
help="Override package app API token creation behavior.",
|
||||
)
|
||||
def import_migration_data(
|
||||
input_file: str | None,
|
||||
target_tenant: str | None,
|
||||
operator_email: str | None,
|
||||
id_strategy: str | None,
|
||||
conflict_strategy: str | None,
|
||||
create_app_api_token_on_import: bool | None,
|
||||
) -> None:
|
||||
try:
|
||||
_require_options(("--input", input_file))
|
||||
assert input_file is not None
|
||||
package = MigrationPackageService().load_package(input_file)
|
||||
result = MigrationImportService().import_package(
|
||||
ImportRequest(
|
||||
package=package,
|
||||
cli_target_tenant=target_tenant,
|
||||
operator_email=operator_email,
|
||||
options_override=_build_options_override(
|
||||
package.metadata.import_options,
|
||||
id_strategy=id_strategy,
|
||||
conflict_strategy=conflict_strategy,
|
||||
create_app_api_token_on_import=create_app_api_token_on_import,
|
||||
),
|
||||
)
|
||||
)
|
||||
_render_report(result.report_items, context=result.report_context)
|
||||
except MigrationDataError as exc:
|
||||
raise click.ClickException(str(exc)) from exc
|
||||
|
||||
|
||||
def parse_index_selection(raw: str, values: list[str]) -> list[str]:
|
||||
normalized = raw.strip().lower()
|
||||
if normalized == "all":
|
||||
return values
|
||||
|
||||
selected: list[str] = []
|
||||
for part in raw.split(","):
|
||||
stripped = part.strip()
|
||||
if not stripped:
|
||||
continue
|
||||
try:
|
||||
index = int(stripped)
|
||||
except ValueError as exc:
|
||||
raise click.ClickException(f"Selection must be 'all' or comma-separated numbers: {raw}") from exc
|
||||
if index < 1 or index > len(values):
|
||||
raise click.ClickException(f"Selection index out of range: {index}")
|
||||
selected.append(values[index - 1])
|
||||
return list(dict.fromkeys(selected))
|
||||
|
||||
|
||||
def _print_wizard_step(title: str) -> None:
|
||||
click.echo("")
|
||||
click.echo(f"==== {title} ====")
|
||||
|
||||
|
||||
def _print_wizard_substep(title: str) -> None:
|
||||
click.echo("")
|
||||
click.echo(f"-- {title} --")
|
||||
|
||||
|
||||
@click.command("app-migration-wizard", help="Interactively export workflow migration data.")
|
||||
def migration_data_wizard() -> None:
|
||||
try:
|
||||
tenant = _prompt_source_tenant()
|
||||
apps = _eligible_apps_for_tenant(tenant.id)
|
||||
app_ids = _prompt_app_ids(apps)
|
||||
_print_wizard_step("Referenced Tools")
|
||||
include_referenced_tools = click.confirm(
|
||||
"Automatically export tools referenced by selected apps? [y/n, default: y]",
|
||||
default=True,
|
||||
show_default=False,
|
||||
)
|
||||
auto_tools = _discover_auto_tools([app for app in apps if app.id in set(app_ids)], include_referenced_tools)
|
||||
auto_tools = _resolve_auto_tool_names(tenant.id, auto_tools)
|
||||
_print_auto_tools(auto_tools)
|
||||
additional_tools = _prompt_additional_tools(tenant.id, auto_tools)
|
||||
include_secrets, create_tokens, id_strategy, conflict_strategy = _prompt_import_options()
|
||||
_print_wizard_step("Output")
|
||||
output_file, overwrite = _prompt_output_file()
|
||||
|
||||
selection = ExportConfigParser().parse(
|
||||
{
|
||||
"source_tenant": {"mode": "single", "id": tenant.id, "name": tenant.name},
|
||||
"apps": {"ids": app_ids, "all": False},
|
||||
"include_referenced_tools": include_referenced_tools,
|
||||
"additional_tools": additional_tools,
|
||||
"include_secrets": include_secrets,
|
||||
"import_options": {
|
||||
"create_app_api_token_on_import": create_tokens,
|
||||
"id_strategy": id_strategy,
|
||||
"conflict_strategy": conflict_strategy,
|
||||
},
|
||||
}
|
||||
)
|
||||
_confirm_wizard_summary(
|
||||
tenant_name=tenant.name,
|
||||
app_names=[app.name for app in apps if app.id in set(app_ids)],
|
||||
auto_tools=auto_tools,
|
||||
additional_tools=additional_tools,
|
||||
manual_labels=_selected_tool_labels_for_tenant(tenant.id, additional_tools),
|
||||
include_referenced_tools=include_referenced_tools,
|
||||
include_secrets=include_secrets,
|
||||
create_tokens=create_tokens,
|
||||
id_strategy=id_strategy,
|
||||
conflict_strategy=conflict_strategy,
|
||||
output_file=output_file,
|
||||
)
|
||||
result = MigrationExportService().export(selection)
|
||||
MigrationPackageService().save_package(result.package, output_file, overwrite=overwrite)
|
||||
click.echo(click.style(f"Output written to {output_file}", fg="green"))
|
||||
_print_wizard_step("Report")
|
||||
_render_report(result.report_items, context=_with_output_path(result.report_context, output_file))
|
||||
except MigrationDataError as exc:
|
||||
raise click.ClickException(str(exc)) from exc
|
||||
|
||||
|
||||
def _load_json_object(path: str, label: str) -> dict[str, Any]:
|
||||
try:
|
||||
with Path(path).open(encoding="utf-8") as file:
|
||||
raw = json.load(file)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise MigrationDataError(f"{label} JSON is invalid: {exc.msg}") from exc
|
||||
if not isinstance(raw, dict):
|
||||
raise MigrationDataError(f"{label} JSON must be an object.")
|
||||
return raw
|
||||
|
||||
|
||||
def _require_options(*options: tuple[str, object | None]) -> None:
|
||||
missing_options = [name for name, value in options if value is None]
|
||||
if missing_options:
|
||||
raise click.UsageError(f"Missing option(s): {', '.join(missing_options)}.")
|
||||
|
||||
|
||||
def _build_options_override(
|
||||
package_options: ImportOptions,
|
||||
*,
|
||||
id_strategy: str | None,
|
||||
conflict_strategy: str | None,
|
||||
create_app_api_token_on_import: bool | None,
|
||||
) -> ImportOptions | None:
|
||||
if id_strategy is None and conflict_strategy is None and create_app_api_token_on_import is None:
|
||||
return None
|
||||
return ImportOptions.from_mapping(
|
||||
{
|
||||
"id_strategy": id_strategy or package_options.id_strategy,
|
||||
"conflict_strategy": conflict_strategy or package_options.conflict_strategy,
|
||||
"create_app_api_token_on_import": (
|
||||
create_app_api_token_on_import
|
||||
if create_app_api_token_on_import is not None
|
||||
else package_options.create_app_api_token_on_import
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _prompt_source_tenant() -> Tenant:
|
||||
tenants = list(db.session.scalars(sa.select(Tenant).order_by(Tenant.name.asc())).all())
|
||||
if not tenants:
|
||||
raise MigrationDataError("No tenants found.")
|
||||
|
||||
_print_wizard_step("Source Tenant")
|
||||
click.echo("Source tenants:")
|
||||
for index, tenant in enumerate(tenants, 1):
|
||||
click.echo(f"{index}. {tenant.name} ({tenant.id})")
|
||||
|
||||
tenant_index = click.prompt("Select one source tenant by number", type=int, default=1, show_default=True)
|
||||
if tenant_index < 1 or tenant_index > len(tenants):
|
||||
raise click.ClickException(f"Selection index out of range: {tenant_index}")
|
||||
return tenants[tenant_index - 1]
|
||||
|
||||
|
||||
def _eligible_apps_for_tenant(tenant_id: str) -> list[App]:
|
||||
return list(
|
||||
db.session.scalars(
|
||||
sa.select(App)
|
||||
.where(App.tenant_id == tenant_id, App.mode.in_(SUPPORTED_WIZARD_APP_MODES))
|
||||
.order_by(App.name.asc())
|
||||
).all()
|
||||
)
|
||||
|
||||
|
||||
def _prompt_app_ids(apps: list[App]) -> list[str]:
|
||||
if not apps:
|
||||
raise MigrationDataError("No workflow or advanced-chat apps found for the selected tenant.")
|
||||
|
||||
_print_wizard_step("App Selection")
|
||||
click.echo("Currently supported app types: workflow and chatflow.")
|
||||
click.echo("Workflow/chatflow apps:")
|
||||
for index, app in enumerate(apps, 1):
|
||||
mode = app.mode.value if hasattr(app.mode, "value") else app.mode
|
||||
click.echo(f"{index}. {app.name} [{mode}] ({app.id})")
|
||||
app_ids = parse_index_selection(
|
||||
click.prompt("Select apps by number, comma-separated numbers, or all", default="all"),
|
||||
[app.id for app in apps],
|
||||
)
|
||||
selected_apps = [app for app in apps if app.id in set(app_ids)]
|
||||
click.echo("Selected apps:")
|
||||
for app in selected_apps:
|
||||
click.echo(f"- {app.name} ({app.id})")
|
||||
return app_ids
|
||||
|
||||
|
||||
def _prompt_import_options() -> tuple[bool, bool, str, str]:
|
||||
_print_wizard_step("Import Options")
|
||||
_print_wizard_substep("Secrets")
|
||||
click.echo("Secrets include workflow/app DSL secret values, custom API tool credentials,")
|
||||
click.echo("and full MCP provider connection data such as server URL, headers, authentication, and tool list.")
|
||||
click.echo("If you choose no, credentials are omitted or masked,")
|
||||
click.echo("and MCP providers are exported as dependency metadata only.")
|
||||
click.echo("Treat the output JSON as sensitive if you choose yes.")
|
||||
include_secrets = click.confirm(
|
||||
"Include secrets in output JSON? [y/n, default: n]",
|
||||
default=False,
|
||||
show_default=False,
|
||||
)
|
||||
_print_wizard_substep("App API Tokens")
|
||||
click.echo("When enabled, import will create an app API token if the imported app has none,")
|
||||
click.echo("or reuse an existing app API token if one already exists.")
|
||||
create_tokens = click.confirm(
|
||||
"Create or reuse app API tokens during import? [y/n, default: n]",
|
||||
default=False,
|
||||
show_default=False,
|
||||
)
|
||||
_print_wizard_substep("ID Strategy")
|
||||
click.echo("ID strategy controls whether imported app and tool IDs preserve source IDs")
|
||||
click.echo("or use target-generated IDs.")
|
||||
click.echo("preserve-id: keep source IDs where the target service supports it.")
|
||||
click.echo("generate-new-id: let the target environment generate new IDs and rewrite references via mapping.")
|
||||
id_strategy = click.prompt(
|
||||
"Import ID strategy. Enter one of: preserve-id, generate-new-id",
|
||||
type=click.Choice(ID_STRATEGY_CHOICES),
|
||||
default="preserve-id",
|
||||
show_default=True,
|
||||
)
|
||||
_print_wizard_substep("Conflict Strategy")
|
||||
click.echo("Conflict strategy controls what import does when a target resource already exists.")
|
||||
click.echo("fail: stop at the first conflict; previously committed resources are not rolled back.")
|
||||
click.echo("skip: keep the existing target resource and skip importing that resource.")
|
||||
click.echo("update: update the existing target resource in place.")
|
||||
conflict_strategy = click.prompt(
|
||||
"Import conflict strategy. Enter one of: fail, skip, update",
|
||||
type=click.Choice(CONFLICT_STRATEGY_CHOICES),
|
||||
default="update",
|
||||
show_default=True,
|
||||
)
|
||||
return include_secrets, create_tokens, id_strategy, conflict_strategy
|
||||
|
||||
|
||||
def _discover_auto_tools(apps: list[App], include_referenced_tools: bool) -> WizardToolMap:
|
||||
auto_tools: WizardToolMap = {"api_tools": {}, "workflow_tools": {}, "mcp_tools": {}}
|
||||
if not include_referenced_tools:
|
||||
return auto_tools
|
||||
discovery_service = DependencyDiscoveryService()
|
||||
for app in apps:
|
||||
dsl_content = AppDslService.export_dsl(app_model=app, include_secret=False)
|
||||
raw_dsl = yaml.safe_load(dsl_content) if dsl_content else {}
|
||||
dsl = raw_dsl if isinstance(raw_dsl, dict) else {}
|
||||
for dependency in discovery_service.discover_from_dsl(dsl):
|
||||
if dependency.kind == DependencyKind.API_TOOL:
|
||||
auto_tools["api_tools"][dependency.provider_name or dependency.provider_id] = dependency.provider_id
|
||||
elif dependency.kind == DependencyKind.WORKFLOW_TOOL:
|
||||
auto_tools["workflow_tools"][dependency.provider_name or dependency.provider_id] = (
|
||||
dependency.provider_id
|
||||
)
|
||||
elif dependency.kind == DependencyKind.MCP_TOOL:
|
||||
auto_tools["mcp_tools"][dependency.provider_name or dependency.provider_id] = dependency.provider_id
|
||||
return auto_tools
|
||||
|
||||
|
||||
def _resolve_auto_tool_names(tenant_id: str, auto_tools: WizardToolMap) -> WizardToolMap:
|
||||
return {
|
||||
"api_tools": _resolve_api_tool_names(tenant_id, auto_tools["api_tools"]),
|
||||
"workflow_tools": _resolve_workflow_tool_names(tenant_id, auto_tools["workflow_tools"]),
|
||||
"mcp_tools": _resolve_mcp_tool_names(tenant_id, auto_tools["mcp_tools"]),
|
||||
}
|
||||
|
||||
|
||||
def _resolve_api_tool_names(tenant_id: str, tools: dict[str, str | None]) -> dict[str, str | None]:
|
||||
resolved: dict[str, str | None] = {}
|
||||
for name, identifier in tools.items():
|
||||
predicates = [ApiToolProvider.name == name]
|
||||
if _is_uuid_string(identifier):
|
||||
predicates.append(ApiToolProvider.id == identifier)
|
||||
provider = db.session.scalar(
|
||||
sa.select(ApiToolProvider).where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
sa.or_(*predicates),
|
||||
)
|
||||
)
|
||||
resolved[provider.name if provider else name] = provider.id if provider else identifier
|
||||
return resolved
|
||||
|
||||
|
||||
def _resolve_workflow_tool_names(tenant_id: str, tools: dict[str, str | None]) -> dict[str, str | None]:
|
||||
resolved: dict[str, str | None] = {}
|
||||
for name, identifier in tools.items():
|
||||
predicates = [WorkflowToolProvider.name == name]
|
||||
if _is_uuid_string(identifier):
|
||||
predicates.append(WorkflowToolProvider.id == identifier)
|
||||
provider = db.session.scalar(
|
||||
sa.select(WorkflowToolProvider).where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
sa.or_(*predicates),
|
||||
)
|
||||
)
|
||||
resolved[provider.name if provider else name] = provider.id if provider else identifier
|
||||
return resolved
|
||||
|
||||
|
||||
def _resolve_mcp_tool_names(tenant_id: str, tools: dict[str, str | None]) -> dict[str, str | None]:
|
||||
resolved: dict[str, str | None] = {}
|
||||
for name, identifier in tools.items():
|
||||
predicates = [MCPToolProvider.name == name]
|
||||
if identifier:
|
||||
predicates.append(MCPToolProvider.server_identifier == identifier)
|
||||
if _is_uuid_string(identifier):
|
||||
predicates.append(MCPToolProvider.id == identifier)
|
||||
provider = db.session.scalar(
|
||||
sa.select(MCPToolProvider).where(
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
sa.or_(*predicates),
|
||||
)
|
||||
)
|
||||
resolved[provider.name if provider else name] = provider.id if provider else identifier
|
||||
return resolved
|
||||
|
||||
|
||||
def _is_uuid_string(value: str | None) -> bool:
|
||||
if not value:
|
||||
return False
|
||||
try:
|
||||
UUID(value)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _print_auto_tools(auto_tools: WizardToolMap) -> None:
|
||||
_print_wizard_step("Automatically Discovered Tools")
|
||||
click.echo("Automatically discovered tools:")
|
||||
_print_auto_tool_category("Custom API tools", auto_tools["api_tools"])
|
||||
_print_auto_tool_category("Workflow tools", auto_tools["workflow_tools"])
|
||||
_print_auto_tool_category("MCP tools", auto_tools["mcp_tools"])
|
||||
|
||||
|
||||
def _print_auto_tool_category(label: str, values: dict[str, str | None]) -> None:
|
||||
click.echo(label)
|
||||
if not values:
|
||||
click.echo("- none")
|
||||
return
|
||||
for name, identifier in sorted(values.items()):
|
||||
click.echo(f"- {_format_tool_name_id(name, identifier)}")
|
||||
|
||||
|
||||
def _prompt_additional_tools(tenant_id: str, auto_tools: WizardToolMap) -> WizardToolSelection:
|
||||
selections: WizardToolSelection = {"api_tools": [], "workflow_tools": [], "mcp_tools": []}
|
||||
_print_wizard_step("Additional Tools")
|
||||
if not click.confirm(
|
||||
"Export additional tools manually? [y/n, default: n]",
|
||||
default=False,
|
||||
show_default=False,
|
||||
):
|
||||
_print_final_tool_selection(auto_tools, selections, {})
|
||||
return selections
|
||||
manual_labels: dict[str, str] = {}
|
||||
api_tool_options = [
|
||||
(tool.name, tool.name, tool.id)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).order_by(ApiToolProvider.name)
|
||||
).all()
|
||||
]
|
||||
selections["api_tools"] = _prompt_tool_category(
|
||||
"Custom API tools",
|
||||
api_tool_options,
|
||||
auto_tools=auto_tools["api_tools"],
|
||||
)
|
||||
manual_labels.update(_selected_tool_labels(api_tool_options, selections["api_tools"]))
|
||||
workflow_tool_options = [
|
||||
(tool.id, tool.name, tool.id)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id)
|
||||
.order_by(WorkflowToolProvider.name)
|
||||
).all()
|
||||
]
|
||||
selections["workflow_tools"] = _prompt_tool_category(
|
||||
"Workflow tools",
|
||||
workflow_tool_options,
|
||||
auto_tools=auto_tools["workflow_tools"],
|
||||
)
|
||||
manual_labels.update(_selected_tool_labels(workflow_tool_options, selections["workflow_tools"]))
|
||||
mcp_tool_options = [
|
||||
(tool.id, tool.name, tool.server_identifier)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name)
|
||||
).all()
|
||||
]
|
||||
selections["mcp_tools"] = _prompt_tool_category(
|
||||
"MCP tools",
|
||||
mcp_tool_options,
|
||||
auto_tools=auto_tools["mcp_tools"],
|
||||
)
|
||||
manual_labels.update(_selected_tool_labels(mcp_tool_options, selections["mcp_tools"]))
|
||||
_print_final_tool_selection(auto_tools, selections, manual_labels)
|
||||
return selections
|
||||
|
||||
|
||||
def _selected_tool_labels_for_tenant(tenant_id: str, selected_tools: WizardToolSelection) -> dict[str, str]:
|
||||
labels: dict[str, str] = {}
|
||||
if selected_tools["api_tools"]:
|
||||
labels.update(
|
||||
_selected_tool_labels(
|
||||
[
|
||||
(tool.name, tool.name, tool.id)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(ApiToolProvider)
|
||||
.where(ApiToolProvider.tenant_id == tenant_id)
|
||||
.order_by(ApiToolProvider.name)
|
||||
).all()
|
||||
],
|
||||
selected_tools["api_tools"],
|
||||
)
|
||||
)
|
||||
if selected_tools["workflow_tools"]:
|
||||
labels.update(
|
||||
_selected_tool_labels(
|
||||
[
|
||||
(tool.id, tool.name, tool.id)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id)
|
||||
.order_by(WorkflowToolProvider.name)
|
||||
).all()
|
||||
],
|
||||
selected_tools["workflow_tools"],
|
||||
)
|
||||
)
|
||||
if selected_tools["mcp_tools"]:
|
||||
labels.update(
|
||||
_selected_tool_labels(
|
||||
[
|
||||
(tool.id, tool.name, tool.server_identifier)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(MCPToolProvider)
|
||||
.where(MCPToolProvider.tenant_id == tenant_id)
|
||||
.order_by(MCPToolProvider.name)
|
||||
).all()
|
||||
],
|
||||
selected_tools["mcp_tools"],
|
||||
)
|
||||
)
|
||||
return labels
|
||||
|
||||
|
||||
def _selected_tool_labels(options: list[tuple[str, str, str]], selected_values: list[str]) -> dict[str, str]:
|
||||
selected = set(selected_values)
|
||||
return {value: _format_tool_name_id(name, detail) for value, name, detail in options if value in selected}
|
||||
|
||||
|
||||
def _prompt_tool_category(
|
||||
label: str,
|
||||
options: list[tuple[str, str, str]],
|
||||
*,
|
||||
auto_tools: dict[str, str | None],
|
||||
) -> list[str]:
|
||||
if not options:
|
||||
click.echo(f"{label}: none")
|
||||
return []
|
||||
_print_wizard_step(label)
|
||||
for index, (value, name, detail) in enumerate(options, 1):
|
||||
marker = "[auto]" if _is_auto_tool(value, name, detail, auto_tools) else "[ ]"
|
||||
click.echo(f"{index}. {marker} {name} ({detail})")
|
||||
raw = click.prompt(
|
||||
f"Select {label.lower()} by number, comma-separated numbers, all, or empty",
|
||||
default="",
|
||||
show_default=cast(Any, "empty"),
|
||||
)
|
||||
if not raw.strip():
|
||||
return []
|
||||
return parse_index_selection(raw, [value for value, _, _ in options])
|
||||
|
||||
|
||||
def _is_auto_tool(value: str, name: str, detail: str, auto_tools: dict[str, str | None]) -> bool:
|
||||
return name in auto_tools or value in auto_tools or value in auto_tools.values() or detail in auto_tools.values()
|
||||
|
||||
|
||||
def _print_final_tool_selection(
|
||||
auto_tools: WizardToolMap,
|
||||
additional_tools: WizardToolSelection,
|
||||
manual_labels: dict[str, str],
|
||||
) -> None:
|
||||
_print_wizard_step("Final Tool Selection")
|
||||
_print_tool_selection_body(auto_tools, additional_tools, manual_labels)
|
||||
|
||||
|
||||
def _print_tool_selection_body(
|
||||
auto_tools: WizardToolMap,
|
||||
additional_tools: WizardToolSelection,
|
||||
manual_labels: dict[str, str],
|
||||
) -> None:
|
||||
click.echo("Final tools to export:")
|
||||
_print_final_tool_category(
|
||||
"Custom API tools",
|
||||
auto_tools["api_tools"],
|
||||
additional_tools["api_tools"],
|
||||
manual_labels,
|
||||
)
|
||||
_print_final_tool_category(
|
||||
"Workflow tools",
|
||||
auto_tools["workflow_tools"],
|
||||
additional_tools["workflow_tools"],
|
||||
manual_labels,
|
||||
)
|
||||
_print_final_tool_category("MCP tools", auto_tools["mcp_tools"], additional_tools["mcp_tools"], manual_labels)
|
||||
|
||||
|
||||
def _print_final_tool_category(
|
||||
label: str,
|
||||
auto_tools: dict[str, str | None],
|
||||
manual_values: list[str],
|
||||
manual_labels: dict[str, str],
|
||||
) -> None:
|
||||
click.echo(label)
|
||||
lines = [f"- [auto] {_format_tool_name_id(name, identifier)}" for name, identifier in sorted(auto_tools.items())]
|
||||
auto_identifiers = {identifier for identifier in auto_tools.values() if identifier}
|
||||
lines.extend(
|
||||
f"- [manual] {manual_labels.get(value, value)}"
|
||||
for value in manual_values
|
||||
if value not in auto_tools and value not in auto_identifiers
|
||||
)
|
||||
if not lines:
|
||||
click.echo("- none")
|
||||
return
|
||||
for line in lines:
|
||||
click.echo(line)
|
||||
|
||||
|
||||
def _format_tool_name_id(name: str, identifier: str | None) -> str:
|
||||
if identifier and identifier != name:
|
||||
return f"{name}: {identifier}"
|
||||
return name
|
||||
|
||||
|
||||
def _confirm_wizard_summary(
|
||||
*,
|
||||
tenant_name: str,
|
||||
app_names: list[str],
|
||||
auto_tools: WizardToolMap,
|
||||
additional_tools: WizardToolSelection,
|
||||
manual_labels: dict[str, str],
|
||||
include_referenced_tools: bool,
|
||||
include_secrets: bool,
|
||||
create_tokens: bool,
|
||||
id_strategy: str,
|
||||
conflict_strategy: str,
|
||||
output_file: str,
|
||||
) -> None:
|
||||
_print_wizard_step("Summary")
|
||||
click.echo("Migration export summary:")
|
||||
click.echo(f"source tenant: {tenant_name}")
|
||||
click.echo(f"selected apps: {len(app_names)}")
|
||||
for app_name in app_names:
|
||||
click.echo(f"- {app_name}")
|
||||
click.echo(f"auto referenced tools: {str(include_referenced_tools).lower()}")
|
||||
_print_tool_selection_body(auto_tools, additional_tools, manual_labels)
|
||||
click.echo(f"include secrets: {str(include_secrets).lower()}")
|
||||
click.echo(f"create app api token on import: {str(create_tokens).lower()}")
|
||||
click.echo(f"id strategy: {id_strategy}")
|
||||
click.echo(f"conflict strategy: {conflict_strategy}")
|
||||
click.echo(f"output path: {output_file}")
|
||||
if not click.confirm("Write migration package? [y/n, default: y]", default=True, show_default=False):
|
||||
raise click.Abort()
|
||||
|
||||
|
||||
def _prompt_output_file() -> tuple[str, bool]:
|
||||
default_output = f"migration-data-{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
||||
output_file = click.prompt("Output path", default=default_output, show_default=True)
|
||||
if output_file.lower() in {"y", "yes", "n", "no"}:
|
||||
raise click.ClickException("Output path must be a file path. Press Enter to use the default path.")
|
||||
overwrite = False
|
||||
if Path(output_file).exists():
|
||||
overwrite = click.confirm(
|
||||
"Output file exists. Overwrite? [y/n, default: n]",
|
||||
default=False,
|
||||
show_default=False,
|
||||
)
|
||||
if not overwrite:
|
||||
raise click.ClickException(f"Output file already exists: {output_file}")
|
||||
return output_file, overwrite
|
||||
|
||||
|
||||
def _with_output_path(context: ReportContext | None, output_path: str) -> ReportContext:
|
||||
if context is None:
|
||||
return ReportContext(output_path=output_path)
|
||||
return ReportContext(
|
||||
output_path=output_path,
|
||||
source_scope=context.source_scope,
|
||||
selected_app_count=context.selected_app_count,
|
||||
include_secrets=context.include_secrets,
|
||||
target_tenant=context.target_tenant,
|
||||
operator_email=context.operator_email,
|
||||
app_api_tokens_created=context.app_api_tokens_created,
|
||||
app_api_tokens_reused=context.app_api_tokens_reused,
|
||||
id_mapping_count=context.id_mapping_count,
|
||||
id_mappings=context.id_mappings,
|
||||
)
|
||||
|
||||
|
||||
def _render_report(report_items: list[ResourceReportItem], *, context: ReportContext | None = None) -> None:
|
||||
for line in MigrationReportService().render(report_items, context=context):
|
||||
click.echo(line)
|
||||
@ -30,7 +30,7 @@ def vdb_migrate(scope: str):
|
||||
|
||||
def migrate_annotation_vector_database():
|
||||
"""
|
||||
Migrate annotation datas to target vector database .
|
||||
Migrate annotation data to target vector database.
|
||||
"""
|
||||
click.echo(click.style("Starting annotation data migration.", fg="green"))
|
||||
create_count = 0
|
||||
@ -140,7 +140,7 @@ def migrate_annotation_vector_database():
|
||||
|
||||
def migrate_knowledge_vector_database():
|
||||
"""
|
||||
Migrate vector database datas to target vector database .
|
||||
Migrate vector database data to target vector database.
|
||||
"""
|
||||
click.echo(click.style("Starting vector database migration.", fg="green"))
|
||||
create_count = 0
|
||||
|
||||
@ -41,3 +41,21 @@ class MilvusConfig(BaseSettings):
|
||||
description='Milvus text analyzer parameters, e.g., {"type": "chinese"} for Chinese segmentation support.',
|
||||
default=None,
|
||||
)
|
||||
|
||||
MILVUS_SECURE: bool = Field(
|
||||
description="Enable TLS for the Milvus connection (one-way TLS). When True, the client uses gRPC over TLS "
|
||||
"and verifies the server certificate. Equivalent to passing secure=True to pymilvus.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
MILVUS_SERVER_PEM_PATH: str | None = Field(
|
||||
description="Filesystem path inside the container to the Milvus server certificate (PEM). Mount this via "
|
||||
"a Kubernetes secret. Used as pymilvus's server_pem_path when MILVUS_SECURE is True.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
MILVUS_SERVER_NAME: str | None = Field(
|
||||
description="Server name (TLS SNI / certificate CN or SAN) to verify against the Milvus server certificate. "
|
||||
"Required when MILVUS_SERVER_PEM_PATH is set.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@ -68,6 +68,7 @@ from .app import (
|
||||
workflow_app_log,
|
||||
workflow_comment,
|
||||
workflow_draft_variable,
|
||||
workflow_node_output_inspector,
|
||||
workflow_run,
|
||||
workflow_statistic,
|
||||
workflow_trigger,
|
||||
@ -218,6 +219,7 @@ __all__ = [
|
||||
"workflow_app_log",
|
||||
"workflow_comment",
|
||||
"workflow_draft_variable",
|
||||
"workflow_node_output_inspector",
|
||||
"workflow_run",
|
||||
"workflow_statistic",
|
||||
"workflow_trigger",
|
||||
|
||||
@ -5,7 +5,7 @@ from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMode
|
||||
from models.model import App, AppMode
|
||||
from services.agent.composer_service import AgentComposerService
|
||||
from services.agent.composer_validator import ComposerConfigValidator
|
||||
from services.entities.agent_entities import ComposerSavePayload
|
||||
@ -19,7 +19,7 @@ class WorkflowAgentComposerApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model, node_id: str):
|
||||
def get(self, app_model: App, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return AgentComposerService.load_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
@ -33,7 +33,7 @@ class WorkflowAgentComposerApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def put(self, app_model, node_id: str):
|
||||
def put(self, app_model: App, node_id: str):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return AgentComposerService.save_workflow_composer(
|
||||
@ -52,7 +52,7 @@ class WorkflowAgentComposerValidateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model, node_id: str):
|
||||
def post(self, app_model: App, node_id: str):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
ComposerConfigValidator.validate_save_payload(payload)
|
||||
return {"result": "success", "errors": []}
|
||||
@ -64,7 +64,7 @@ class WorkflowAgentComposerCandidatesApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model, node_id: str):
|
||||
def get(self, app_model: App, node_id: str):
|
||||
return AgentComposerService.get_workflow_candidates(app_id=app_model.id)
|
||||
|
||||
|
||||
@ -74,7 +74,7 @@ class WorkflowAgentComposerImpactApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model, node_id: str):
|
||||
def post(self, app_model: App, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
current_snapshot_id = payload.binding.current_snapshot_id if payload.binding else None
|
||||
@ -91,7 +91,7 @@ class WorkflowAgentComposerSaveToRosterApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model, node_id: str):
|
||||
def post(self, app_model: App, node_id: str):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return AgentComposerService.save_workflow_composer(
|
||||
@ -109,7 +109,7 @@ class AgentAppComposerApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return AgentComposerService.load_agent_app_composer(tenant_id=tenant_id, app_id=app_model.id)
|
||||
|
||||
@ -119,7 +119,7 @@ class AgentAppComposerApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model()
|
||||
def put(self, app_model):
|
||||
def put(self, app_model: App):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return AgentComposerService.save_agent_app_composer(
|
||||
@ -137,7 +137,7 @@ class AgentAppComposerValidateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
ComposerConfigValidator.validate_save_payload(payload)
|
||||
return {"result": "success", "errors": []}
|
||||
@ -149,5 +149,5 @@ class AgentAppComposerCandidatesApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
return AgentComposerService.get_agent_app_candidates(app_id=app_model.id)
|
||||
|
||||
@ -9,18 +9,25 @@ from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.helper import dump_response, to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import Dataset
|
||||
from models.enums import ApiTokenType
|
||||
from models.model import ApiToken, App
|
||||
from services.api_token_service import ApiTokenCache
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from .wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
|
||||
|
||||
class ApiKeyItem(ResponseModel):
|
||||
@ -40,7 +47,7 @@ class ApiKeyList(ResponseModel):
|
||||
data: list[ApiKeyItem]
|
||||
|
||||
|
||||
register_schema_models(console_ns, ApiKeyItem, ApiKeyList)
|
||||
register_response_schema_models(console_ns, ApiKeyItem, ApiKeyList)
|
||||
|
||||
|
||||
def _get_resource(resource_id, tenant_id, resource_model):
|
||||
@ -64,10 +71,11 @@ class BaseApiKeyListResource(Resource):
|
||||
token_prefix: str | None = None
|
||||
max_keys = 10
|
||||
|
||||
def get(self, resource_id):
|
||||
def get(self, resource_id: str, current_tenant_id: str) -> dict[str, object]:
|
||||
return dump_response(ApiKeyList, self._get_api_key_list(resource_id, current_tenant_id))
|
||||
|
||||
def _get_api_key_list(self, resource_id: str, current_tenant_id: str) -> ApiKeyList:
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
keys = db.session.scalars(
|
||||
@ -75,13 +83,14 @@ class BaseApiKeyListResource(Resource):
|
||||
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
||||
)
|
||||
).all()
|
||||
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
|
||||
return ApiKeyList.model_validate({"data": keys}, from_attributes=True)
|
||||
|
||||
@edit_permission_required
|
||||
def post(self, resource_id):
|
||||
def post(self, resource_id: str, current_tenant_id: str) -> tuple[dict[str, object], int]:
|
||||
return dump_response(ApiKeyItem, self._create_api_key(resource_id, current_tenant_id)), 201
|
||||
|
||||
def _create_api_key(self, resource_id: str, current_tenant_id: str) -> ApiToken:
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
current_key_count: int = (
|
||||
db.session.scalar(
|
||||
@ -108,7 +117,7 @@ class BaseApiKeyListResource(Resource):
|
||||
api_token.type = self.resource_type
|
||||
db.session.add(api_token)
|
||||
db.session.commit()
|
||||
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201
|
||||
return api_token
|
||||
|
||||
|
||||
class BaseApiKeyResource(Resource):
|
||||
@ -118,9 +127,20 @@ class BaseApiKeyResource(Resource):
|
||||
resource_model: type | None = None
|
||||
resource_id_field: str | None = None
|
||||
|
||||
def delete(self, resource_id: str, api_key_id: str):
|
||||
def delete(
|
||||
self, resource_id: str, api_key_id: str, current_tenant_id: str, current_user: Account
|
||||
) -> tuple[str, int]:
|
||||
self._delete_api_key(resource_id, api_key_id, current_tenant_id, current_user)
|
||||
return "", 204
|
||||
|
||||
def _delete_api_key(
|
||||
self,
|
||||
resource_id: str,
|
||||
api_key_id: str,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
) -> None:
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
@ -147,8 +167,6 @@ class BaseApiKeyResource(Resource):
|
||||
db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id))
|
||||
db.session.commit()
|
||||
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:resource_id>/api-keys")
|
||||
class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
@ -156,18 +174,21 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc(description="Get all API keys for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||
def get(self, resource_id: UUID):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, resource_id: UUID) -> dict[str, object]:
|
||||
"""Get all API keys for an app"""
|
||||
return super().get(resource_id)
|
||||
return dump_response(ApiKeyList, self._get_api_key_list(str(resource_id), current_tenant_id))
|
||||
|
||||
@console_ns.doc("create_app_api_key")
|
||||
@console_ns.doc(description="Create a new API key for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id: UUID):
|
||||
@with_current_tenant_id
|
||||
@edit_permission_required
|
||||
def post(self, current_tenant_id: str, resource_id: UUID) -> tuple[dict[str, object], int]:
|
||||
"""Create a new API key for an app"""
|
||||
return super().post(resource_id)
|
||||
return dump_response(ApiKeyItem, self._create_api_key(str(resource_id), current_tenant_id)), 201
|
||||
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
@ -181,9 +202,14 @@ class AppApiKeyResource(BaseApiKeyResource):
|
||||
@console_ns.doc(description="Delete an API key for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
|
||||
@console_ns.response(204, "API key deleted successfully")
|
||||
def delete(self, resource_id: UUID, api_key_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(
|
||||
self, current_tenant_id: str, current_user: Account, resource_id: UUID, api_key_id: UUID
|
||||
) -> tuple[str, int]:
|
||||
"""Delete an API key for an app"""
|
||||
return super().delete(str(resource_id), str(api_key_id))
|
||||
self._delete_api_key(str(resource_id), str(api_key_id), current_tenant_id, current_user)
|
||||
return "", 204
|
||||
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
@ -196,18 +222,21 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc(description="Get all API keys for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||
def get(self, resource_id: UUID):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, resource_id: UUID) -> dict[str, object]:
|
||||
"""Get all API keys for a dataset"""
|
||||
return super().get(resource_id)
|
||||
return dump_response(ApiKeyList, self._get_api_key_list(str(resource_id), current_tenant_id))
|
||||
|
||||
@console_ns.doc("create_dataset_api_key")
|
||||
@console_ns.doc(description="Create a new API key for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id: UUID):
|
||||
@with_current_tenant_id
|
||||
@edit_permission_required
|
||||
def post(self, current_tenant_id: str, resource_id: UUID) -> tuple[dict[str, object], int]:
|
||||
"""Create a new API key for a dataset"""
|
||||
return super().post(resource_id)
|
||||
return dump_response(ApiKeyItem, self._create_api_key(str(resource_id), current_tenant_id)), 201
|
||||
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_model = Dataset
|
||||
@ -221,9 +250,14 @@ class DatasetApiKeyResource(BaseApiKeyResource):
|
||||
@console_ns.doc(description="Delete an API key for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
|
||||
@console_ns.response(204, "API key deleted successfully")
|
||||
def delete(self, resource_id: UUID, api_key_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(
|
||||
self, current_tenant_id: str, current_user: Account, resource_id: UUID, api_key_id: UUID
|
||||
) -> tuple[str, int]:
|
||||
"""Delete an API key for a dataset"""
|
||||
return super().delete(str(resource_id), str(api_key_id))
|
||||
self._delete_api_key(str(resource_id), str(api_key_id), current_tenant_id, current_user)
|
||||
return "", 204
|
||||
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_model = Dataset
|
||||
|
||||
@ -8,7 +8,7 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
from models.model import App, AppMode
|
||||
from services.agent_service import AgentService
|
||||
|
||||
|
||||
@ -39,7 +39,7 @@ class AgentLogApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT])
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
"""Get agent logs"""
|
||||
args = AgentLogQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
|
||||
@ -573,7 +573,7 @@ class AppApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@get_app_model(mode=None)
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
"""Get app detail"""
|
||||
app_service = AppService()
|
||||
|
||||
@ -581,7 +581,7 @@ class AppApi(Resource):
|
||||
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
|
||||
app_model.access_mode = app_setting.access_mode
|
||||
app_model.access_mode = app_setting.access_mode # type: ignore[attr-defined]
|
||||
|
||||
response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
|
||||
return response_model.model_dump(mode="json")
|
||||
@ -598,7 +598,7 @@ class AppApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def put(self, app_model):
|
||||
def put(self, app_model: App):
|
||||
"""Update app"""
|
||||
args = UpdateAppPayload.model_validate(console_ns.payload)
|
||||
|
||||
@ -627,7 +627,7 @@ class AppApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_model):
|
||||
def delete(self, app_model: App):
|
||||
"""Delete app"""
|
||||
app_service = AppService()
|
||||
app_service.delete_app(app_model)
|
||||
@ -648,7 +648,7 @@ class AppCopyApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
"""Copy app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@ -709,7 +709,7 @@ class AppExportApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
"""Export app"""
|
||||
args = AppExportQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
@ -731,7 +731,7 @@ class AppPublishToCreatorsPlatformApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
"""Publish app to Creators Platform"""
|
||||
from configs import dify_config
|
||||
from core.helper.creators import get_redirect_url, upload_dsl
|
||||
@ -762,7 +762,7 @@ class AppNameApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
args = AppNamePayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
@ -784,7 +784,7 @@ class AppIconApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
args = AppIconPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
app_service = AppService()
|
||||
@ -811,7 +811,7 @@ class AppSiteStatus(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
args = AppSiteStatusPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
@ -833,7 +833,7 @@ class AppApiStatus(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
args = AppApiStatusPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
@ -874,7 +874,7 @@ class AppTraceApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
# add app trace
|
||||
args = AppTracePayload.model_validate(console_ns.payload)
|
||||
|
||||
|
||||
@ -70,7 +70,7 @@ class ChatMessageAudioApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
file = request.files["file"]
|
||||
|
||||
try:
|
||||
@ -171,7 +171,7 @@ class TextModesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
try:
|
||||
args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
|
||||
@ -33,7 +33,7 @@ from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
from models.model import App, AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
@ -84,7 +84,7 @@ class CompletionMessageApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
args_model = CompletionMessagePayload.model_validate(console_ns.payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
@ -131,7 +131,7 @@ class CompletionMessageStopApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model, task_id: str):
|
||||
def post(self, app_model: App, task_id: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
@ -159,7 +159,7 @@ class ChatMessageApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
args_model = ChatMessagePayload.model_validate(console_ns.payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
@ -212,7 +212,7 @@ class ChatMessageStopApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model, task_id: str):
|
||||
def post(self, app_model: App, task_id: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
|
||||
@ -33,7 +33,7 @@ from fields.conversation_fields import (
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||
from models.model import AppMode
|
||||
from models.model import App, AppMode
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
||||
@ -93,7 +93,7 @@ class CompletionConversationApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
@ -165,7 +165,7 @@ class CompletionConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def get(self, app_model, conversation_id: UUID):
|
||||
def get(self, app_model: App, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
return ConversationMessageDetailResponse.model_validate(
|
||||
_get_conversation(app_model, conversation_id_str), from_attributes=True
|
||||
@ -182,7 +182,7 @@ class CompletionConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def delete(self, app_model, conversation_id: UUID):
|
||||
def delete(self, app_model: App, conversation_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
conversation_id_str = str(conversation_id)
|
||||
|
||||
@ -207,7 +207,7 @@ class ChatConversationApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
@ -318,7 +318,7 @@ class ChatConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model, conversation_id: UUID):
|
||||
def get(self, app_model: App, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
return ConversationDetailResponse.model_validate(
|
||||
_get_conversation(app_model, conversation_id_str), from_attributes=True
|
||||
@ -335,7 +335,7 @@ class ChatConversationDetailApi(Resource):
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_model, conversation_id: UUID):
|
||||
def delete(self, app_model: App, conversation_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
conversation_id_str = str(conversation_id)
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@ from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import ConversationVariable
|
||||
from models.model import AppMode
|
||||
from models.model import App, AppMode
|
||||
|
||||
|
||||
class ConversationVariablesQuery(BaseModel):
|
||||
@ -94,7 +94,7 @@ class ConversationVariablesApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
stmt = (
|
||||
|
||||
@ -11,7 +11,7 @@ from controllers.console.app.error import (
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
@ -22,7 +22,7 @@ from core.llm_generator.llm_generator import LLMGenerator
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
@ -64,9 +64,9 @@ class RuleGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
args = RuleGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args)
|
||||
@ -93,9 +93,9 @@ class RuleCodeGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
code_result = LLMGenerator.generate_code(
|
||||
@ -125,9 +125,9 @@ class RuleStructuredOutputGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
structured_output = LLMGenerator.generate_structured_output(
|
||||
@ -157,9 +157,9 @@ class InstructionGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
args = InstructionGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] | None = next(
|
||||
(p for p in providers if p.is_accept_language(args.language)), None
|
||||
|
||||
@ -11,13 +11,18 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.enums import AppMCPServerStatus
|
||||
from models.model import AppMCPServer
|
||||
from models.model import App, AppMCPServer
|
||||
|
||||
|
||||
class MCPServerCreatePayload(BaseModel):
|
||||
@ -73,7 +78,7 @@ class AppMCPServerController(Resource):
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
|
||||
if server is None:
|
||||
return {}
|
||||
@ -92,8 +97,8 @@ class AppMCPServerController(Resource):
|
||||
@login_required
|
||||
@setup_required
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, app_model: App):
|
||||
payload = MCPServerCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
description = payload.description
|
||||
@ -127,7 +132,7 @@ class AppMCPServerController(Resource):
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def put(self, app_model):
|
||||
def put(self, app_model: App):
|
||||
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
|
||||
server = db.session.get(AppMCPServer, payload.id)
|
||||
if not server:
|
||||
@ -163,8 +168,8 @@ class AppMCPServerRefreshController(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, server_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, server_id: UUID):
|
||||
server = db.session.scalar(
|
||||
select(AppMCPServer)
|
||||
.where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id)
|
||||
|
||||
@ -45,7 +45,7 @@ from libs.helper import to_timestamp, uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from models.model import App, AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService, attach_message_extra_contents
|
||||
@ -180,7 +180,7 @@ class ChatMessageListApi(Resource):
|
||||
@setup_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict())
|
||||
|
||||
conversation = db.session.scalar(
|
||||
@ -257,7 +257,7 @@ class MessageFeedbackApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = MessageFeedbackPayload.model_validate(console_ns.payload)
|
||||
@ -314,7 +314,7 @@ class MessageAnnotationCountApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
count = db.session.scalar(
|
||||
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.app_id == app_model.id)
|
||||
)
|
||||
@ -337,7 +337,7 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model, message_id: UUID):
|
||||
def get(self, app_model: App, message_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
message_id_str = str(message_id)
|
||||
|
||||
@ -379,7 +379,7 @@ class MessageFeedbackExportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
args = FeedbackExportQuery.model_validate(request.args.to_dict())
|
||||
|
||||
# Import the service function
|
||||
@ -417,7 +417,7 @@ class MessageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model, message_id: UUID):
|
||||
def get(self, app_model: App, message_id: UUID):
|
||||
message_id_str = str(message_id)
|
||||
|
||||
message = db.session.scalar(
|
||||
|
||||
@ -16,7 +16,7 @@ from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMode, AppModelConfig
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
|
||||
@ -52,7 +52,7 @@ class ModelConfigResource(Resource):
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
"""Modify app model config"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# validate config
|
||||
|
||||
@ -20,6 +20,7 @@ from fields.base import ResponseModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Site
|
||||
from models.model import App
|
||||
|
||||
|
||||
class AppSiteUpdatePayload(BaseModel):
|
||||
@ -84,7 +85,7 @@ class AppSite(Resource):
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
@ -133,7 +134,7 @@ class AppSiteAccessTokenReset(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@ from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import convert_datetime_to_date
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AppMode
|
||||
from models.model import App
|
||||
|
||||
|
||||
class StatisticTimeRangeQuery(BaseModel):
|
||||
@ -47,7 +48,7 @@ class DailyMessageStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -61,8 +62,12 @@ FROM
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
@ -104,7 +109,7 @@ class DailyConversationStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -118,8 +123,12 @@ FROM
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
@ -160,7 +169,7 @@ class DailyTerminalsStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -174,8 +183,12 @@ FROM
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
@ -217,7 +230,7 @@ class DailyTokenCostStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -232,8 +245,12 @@ FROM
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
@ -277,7 +294,7 @@ class AverageSessionInteractionStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -299,8 +316,12 @@ FROM
|
||||
WHERE
|
||||
c.app_id = :app_id
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
@ -353,7 +374,7 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -371,8 +392,12 @@ LEFT JOIN
|
||||
WHERE
|
||||
m.app_id = :app_id
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
@ -419,7 +444,7 @@ class AverageResponseTimeStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -433,8 +458,12 @@ FROM
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
@ -476,7 +505,7 @@ class TokensPerSecondStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
@ -492,8 +521,12 @@ FROM
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
|
||||
@ -83,13 +83,14 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
|
||||
# create a copy of the value to avoid affecting the model cache.
|
||||
value = value.model_copy(deep=True)
|
||||
# Refresh the url signature before returning it to client.
|
||||
if isinstance(value, FileSegment):
|
||||
file = value.value
|
||||
file.remote_url = file.generate_url()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
files = value.value
|
||||
for file in files:
|
||||
match value:
|
||||
case FileSegment():
|
||||
file = value.value
|
||||
file.remote_url = file.generate_url()
|
||||
case ArrayFileSegment():
|
||||
files = value.value
|
||||
for file in files:
|
||||
file.remote_url = file.generate_url()
|
||||
return _convert_values_to_json_serializable_object(value)
|
||||
|
||||
|
||||
|
||||
415
api/controllers/console/app/workflow_node_output_inspector.py
Normal file
415
api/controllers/console/app/workflow_node_output_inspector.py
Normal file
@ -0,0 +1,415 @@
|
||||
"""Console REST endpoints for the Node Output Inspector (Stage 4 §8 / §10.3).
|
||||
|
||||
PRD §Node Output Inspector replaces the consumer-organized Variable Inspector
|
||||
with a producer-organized view of each node's declared outputs and their
|
||||
per-run status. This module exposes two parallel sets of three read-only
|
||||
endpoints — one for ``/workflows/draft/runs/...`` (Composer test runs) and one
|
||||
for ``/workflows/published/runs/...`` (real App API / webapp / webhook /
|
||||
schedule / plugin triggers). Both sets share the same service code, the same
|
||||
response shapes, and the same error codes; the URL is the *only* difference,
|
||||
so the frontend can pick the right prefix based on which run-detail page the
|
||||
user is on.
|
||||
|
||||
Decision D-1 (published Inspector deferred) was lifted 2026-05-26 — the
|
||||
``published_run_inspector_not_implemented`` 404 code is therefore no longer
|
||||
produced.
|
||||
|
||||
URLs follow the design doc and reuse the existing
|
||||
``/apps/<uuid:app_id>/workflows/draft/...`` prefix from
|
||||
:mod:`controllers.console.app.workflow_draft_variable`. The
|
||||
``published`` prefix mirrors it shape-for-shape.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.exception import BaseHTTPException
|
||||
from libs.login import login_required
|
||||
from models import App, AppMode
|
||||
from services.workflow import inspector_events
|
||||
from services.workflow.node_output_inspector_service import (
|
||||
NodeOutputInspectorError,
|
||||
NodeOutputInspectorService,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Heartbeat cadence — every N empty subscribe ticks emit a SSE comment so
|
||||
# intervening proxies (nginx, ingress) don't reap the idle connection.
|
||||
# ``inspector_events.subscribe`` ticks at 1s, so 15 → 15s heartbeat.
|
||||
_HEARTBEAT_EVERY_TICKS = 15
|
||||
# Hard ceiling on a single stream — if we never see a terminal workflow
|
||||
# event (engine crashed, redis dropped the message), force-close after this
|
||||
# many ticks (= seconds).
|
||||
_STREAM_HARD_TIMEOUT_TICKS = 1800 # 30 min
|
||||
|
||||
|
||||
def _service() -> NodeOutputInspectorService:
|
||||
"""One-line factory so tests can monkeypatch a stub if needed."""
|
||||
return NodeOutputInspectorService()
|
||||
|
||||
|
||||
def _serve_snapshot(app_model: App, run_id: UUID) -> dict:
|
||||
"""Resource-body shared by draft + published snapshot endpoints.
|
||||
|
||||
Pulled out so the 6 REST routes don't duplicate the same 6-line try/except
|
||||
+ ``model_dump`` ritual — the routes shrink to one-liners and the actual
|
||||
behaviour lives here, where unit tests can hit it without spinning up
|
||||
Flask request context.
|
||||
"""
|
||||
try:
|
||||
snapshot = _service().snapshot_workflow_run(app_model=app_model, workflow_run_id=str(run_id))
|
||||
except NodeOutputInspectorError as error:
|
||||
raise _InspectorNotFound(error) from error
|
||||
return snapshot.model_dump(mode="json")
|
||||
|
||||
|
||||
def _serve_node_detail(app_model: App, run_id: UUID, node_id: str) -> dict:
|
||||
"""Resource-body shared by draft + published node-detail endpoints."""
|
||||
try:
|
||||
view = _service().node_detail(
|
||||
app_model=app_model,
|
||||
workflow_run_id=str(run_id),
|
||||
node_id=node_id,
|
||||
)
|
||||
except NodeOutputInspectorError as error:
|
||||
raise _InspectorNotFound(error) from error
|
||||
return view.model_dump(mode="json")
|
||||
|
||||
|
||||
def _serve_output_preview(app_model: App, run_id: UUID, node_id: str, output_name: str) -> dict:
|
||||
"""Resource-body shared by draft + published output-preview endpoints."""
|
||||
try:
|
||||
preview = _service().output_preview(
|
||||
app_model=app_model,
|
||||
workflow_run_id=str(run_id),
|
||||
node_id=node_id,
|
||||
output_name=output_name,
|
||||
)
|
||||
except NodeOutputInspectorError as error:
|
||||
raise _InspectorNotFound(error) from error
|
||||
return preview.model_dump(mode="json")
|
||||
|
||||
|
||||
class _InspectorNotFound(BaseHTTPException):
|
||||
"""404 that preserves the inspector's specific error code.
|
||||
|
||||
Without this the response body collapses to a generic ``not_found`` code
|
||||
and clients lose the ability to distinguish, e.g.,
|
||||
``workflow_run_not_found`` from ``published_run_inspector_not_implemented``.
|
||||
"""
|
||||
|
||||
code = 404
|
||||
|
||||
def __init__(self, error: NodeOutputInspectorError) -> None:
|
||||
self.error_code = error.code
|
||||
super().__init__(description=str(error))
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs")
|
||||
class WorkflowDraftRunNodeOutputsApi(Resource):
|
||||
"""Whole-run snapshot organized by producer node."""
|
||||
|
||||
@console_ns.doc("get_workflow_draft_run_node_outputs")
|
||||
@console_ns.doc(description="Snapshot of every node's declared outputs for a draft workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
return _serve_snapshot(app_model, run_id)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs/<string:node_id>")
|
||||
class WorkflowDraftRunNodeOutputDetailApi(Resource):
|
||||
"""One node's declared outputs + per-output status."""
|
||||
|
||||
@console_ns.doc("get_workflow_draft_run_node_output_detail")
|
||||
@console_ns.doc(description="One node's declared outputs for a draft workflow run.")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"run_id": "Workflow run ID",
|
||||
"node_id": "Node ID inside the workflow graph",
|
||||
}
|
||||
)
|
||||
@console_ns.response(404, "Workflow run / node not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID, node_id: str):
|
||||
return _serve_node_detail(app_model, run_id, node_id)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs/<string:node_id>/<string:output_name>/preview"
|
||||
)
|
||||
class WorkflowDraftRunNodeOutputPreviewApi(Resource):
|
||||
"""Full value for one declared output (with signed URL for file refs)."""
|
||||
|
||||
@console_ns.doc("get_workflow_draft_run_node_output_preview")
|
||||
@console_ns.doc(description="Full value for one declared output, including signed download URL for files.")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"run_id": "Workflow run ID",
|
||||
"node_id": "Node ID inside the workflow graph",
|
||||
"output_name": "Declared output name as exposed by Composer",
|
||||
}
|
||||
)
|
||||
@console_ns.response(404, "Workflow run / node / output not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID, node_id: str, output_name: str):
|
||||
return _serve_output_preview(app_model, run_id, node_id, output_name)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# SSE event stream — shared generator used by draft + published variants
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _sse_envelope(event: str, data: dict | str, event_id: int) -> str:
|
||||
"""Format one SSE record per D-5 ``{event, data, id}`` envelope.
|
||||
|
||||
``data`` is JSON-serialized when given as a dict; raw strings are
|
||||
forwarded unchanged so we can also emit ``:keepalive`` comment lines.
|
||||
"""
|
||||
payload = data if isinstance(data, str) else json.dumps(data, ensure_ascii=False)
|
||||
return f"event: {event}\nid: {event_id}\ndata: {payload}\n\n"
|
||||
|
||||
|
||||
def _stream_inspector_events(app_model: App, run_id: UUID) -> Iterator[str]:
|
||||
"""Yield SSE-framed strings for one workflow run.
|
||||
|
||||
The stream begins with a full ``snapshot`` event so the client has a
|
||||
starting state without needing a separate REST GET. Then for every
|
||||
``node_changed`` message from the pub/sub channel we re-read that node
|
||||
from DB and push a fresh ``node_changed`` event. When the workflow run
|
||||
reaches a terminal state we push one final ``workflow_run_completed``
|
||||
event and close the stream.
|
||||
|
||||
Failures inside the loop are caught and surfaced as ``error`` events so
|
||||
the frontend can show a banner rather than seeing the connection drop
|
||||
silently. The Inspector never raises across the SSE boundary.
|
||||
"""
|
||||
service = _service()
|
||||
run_id_str = str(run_id)
|
||||
|
||||
# Initial snapshot — also flushes a 404 back at the client right away
|
||||
# if the run is gone (raised before yielding any bytes, so Flask turns it
|
||||
# into the normal HTTP 404 path).
|
||||
try:
|
||||
snapshot = service.snapshot_workflow_run(app_model=app_model, workflow_run_id=run_id_str)
|
||||
except NodeOutputInspectorError as error:
|
||||
raise _InspectorNotFound(error) from error
|
||||
|
||||
event_id = 0
|
||||
yield _sse_envelope("snapshot", snapshot.model_dump(mode="json"), event_id)
|
||||
|
||||
# If the run already finished by the time the client connected, emit
|
||||
# the terminal envelope synchronously and close — no point subscribing.
|
||||
# The enum value for partial success is the hyphenated ``partial-succeeded``
|
||||
# (graphon.enums.WorkflowExecutionStatus), not ``partial_succeeded``.
|
||||
if snapshot.workflow_run_status.value in {"succeeded", "failed", "stopped", "partial-succeeded"}:
|
||||
event_id += 1
|
||||
yield _sse_envelope(
|
||||
"workflow_run_completed",
|
||||
{"workflow_run_id": run_id_str, "workflow_run_status": snapshot.workflow_run_status.value},
|
||||
event_id,
|
||||
)
|
||||
return
|
||||
|
||||
# Live subscription
|
||||
ticks_since_heartbeat = 0
|
||||
total_ticks = 0
|
||||
for message in inspector_events.subscribe(run_id_str, timeout_seconds=1.0):
|
||||
total_ticks += 1
|
||||
if total_ticks > _STREAM_HARD_TIMEOUT_TICKS:
|
||||
logger.warning(
|
||||
"Inspector SSE: forcing close after %ds without terminal event for run %s",
|
||||
_STREAM_HARD_TIMEOUT_TICKS,
|
||||
run_id_str,
|
||||
)
|
||||
return
|
||||
|
||||
# Heartbeat sentinel — ``inspector_events.subscribe`` synthesizes a
|
||||
# ``node_changed`` message with both fields ``None`` on every redis
|
||||
# timeout. Real ``workflow_completed`` messages keep their kind even
|
||||
# when status couldn't be resolved (publisher race), so checking kind
|
||||
# first makes the heartbeat branch safe.
|
||||
if message.kind == "node_changed" and message.node_id is None and message.status is None:
|
||||
ticks_since_heartbeat += 1
|
||||
if ticks_since_heartbeat >= _HEARTBEAT_EVERY_TICKS:
|
||||
yield ":keepalive\n\n"
|
||||
ticks_since_heartbeat = 0
|
||||
continue
|
||||
ticks_since_heartbeat = 0
|
||||
|
||||
if message.kind == "workflow_completed":
|
||||
event_id += 1
|
||||
yield _sse_envelope(
|
||||
"workflow_run_completed",
|
||||
{"workflow_run_id": run_id_str, "workflow_run_status": message.status or "unknown"},
|
||||
event_id,
|
||||
)
|
||||
return
|
||||
|
||||
# node_changed: recompute the node slice from DB
|
||||
if not message.node_id:
|
||||
continue
|
||||
try:
|
||||
node_view = service.node_detail(
|
||||
app_model=app_model,
|
||||
workflow_run_id=run_id_str,
|
||||
node_id=message.node_id,
|
||||
)
|
||||
except NodeOutputInspectorError:
|
||||
# Node may not appear in the graph yet (race with persistence); skip.
|
||||
continue
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Inspector SSE: node_detail failed for run %s node %s",
|
||||
run_id_str,
|
||||
message.node_id,
|
||||
exc_info=True,
|
||||
)
|
||||
event_id += 1
|
||||
yield _sse_envelope(
|
||||
"error",
|
||||
{"node_id": message.node_id, "message": "failed to refresh node detail"},
|
||||
event_id,
|
||||
)
|
||||
continue
|
||||
|
||||
event_id += 1
|
||||
yield _sse_envelope("node_changed", node_view.model_dump(mode="json"), event_id)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs/events")
|
||||
class WorkflowDraftRunNodeOutputEventsApi(Resource):
|
||||
"""SSE stream of inspector deltas for a draft run."""
|
||||
|
||||
@console_ns.doc("stream_workflow_draft_run_node_output_events")
|
||||
@console_ns.doc(description="Server-Sent Events stream of inspector deltas for a draft workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
return Response(
|
||||
_stream_inspector_events(app_model, run_id),
|
||||
mimetype="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Published-run endpoints — symmetric to the draft trio above
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>/node-outputs")
|
||||
class WorkflowPublishedRunNodeOutputsApi(Resource):
|
||||
"""Whole-run snapshot for a *published* workflow run.
|
||||
|
||||
Same response shape as the ``/draft/`` variant — frontend can multiplex
|
||||
based on which page (Composer test-run vs. Run History) is mounted.
|
||||
"""
|
||||
|
||||
@console_ns.doc("get_workflow_published_run_node_outputs")
|
||||
@console_ns.doc(description="Snapshot of every node's declared outputs for a published workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
return _serve_snapshot(app_model, run_id)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>/node-outputs/<string:node_id>")
|
||||
class WorkflowPublishedRunNodeOutputDetailApi(Resource):
|
||||
"""One node's declared outputs + per-output status (published run)."""
|
||||
|
||||
@console_ns.doc("get_workflow_published_run_node_output_detail")
|
||||
@console_ns.doc(description="One node's declared outputs for a published workflow run.")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"run_id": "Workflow run ID",
|
||||
"node_id": "Node ID inside the workflow graph",
|
||||
}
|
||||
)
|
||||
@console_ns.response(404, "Workflow run / node not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID, node_id: str):
|
||||
return _serve_node_detail(app_model, run_id, node_id)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>"
|
||||
"/node-outputs/<string:node_id>/<string:output_name>/preview"
|
||||
)
|
||||
class WorkflowPublishedRunNodeOutputPreviewApi(Resource):
|
||||
"""Full value for one declared output of a published run."""
|
||||
|
||||
@console_ns.doc("get_workflow_published_run_node_output_preview")
|
||||
@console_ns.doc(description="Full value for one declared output of a published run.")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"run_id": "Workflow run ID",
|
||||
"node_id": "Node ID inside the workflow graph",
|
||||
"output_name": "Declared output name as exposed by Composer",
|
||||
}
|
||||
)
|
||||
@console_ns.response(404, "Workflow run / node / output not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID, node_id: str, output_name: str):
|
||||
return _serve_output_preview(app_model, run_id, node_id, output_name)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>/node-outputs/events")
|
||||
class WorkflowPublishedRunNodeOutputEventsApi(Resource):
|
||||
"""SSE stream of inspector deltas for a published run."""
|
||||
|
||||
@console_ns.doc("stream_workflow_published_run_node_output_events")
|
||||
@console_ns.doc(description="Server-Sent Events stream of inspector deltas for a published workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
return Response(
|
||||
_stream_inspector_events(app_model, run_id),
|
||||
mimetype="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
@ -11,7 +11,7 @@ from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import AppMode
|
||||
from models.model import App, AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
|
||||
@ -46,7 +46,7 @@ class WorkflowDailyRunsStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -86,7 +86,7 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -126,7 +126,7 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -166,7 +166,7 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
@ -5,12 +5,12 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
from .. import console_ns
|
||||
from ..auth.error import ApiKeyAuthFailedError
|
||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required, with_current_tenant_id
|
||||
|
||||
|
||||
class ApiKeyAuthBindingPayload(BaseModel):
|
||||
@ -42,8 +42,8 @@ class ApiKeyAuthDataSource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
|
||||
if data_source_api_key_bindings:
|
||||
return {
|
||||
@ -69,9 +69,9 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload)
|
||||
data = payload.model_dump()
|
||||
ApiKeyAuthService.validate_api_key_auth_args(data)
|
||||
@ -89,10 +89,9 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
@console_ns.response(204, "Binding deleted successfully")
|
||||
def delete(self, binding_id: UUID):
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, binding_id: UUID):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(current_tenant_id, str(binding_id))
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -8,9 +8,9 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.model import OAuthProviderApp
|
||||
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
|
||||
@ -133,12 +133,10 @@ class OAuthServerUserAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
user_account_id = account.id
|
||||
|
||||
def post(self, oauth_provider_app: OAuthProviderApp, current_user: Account):
|
||||
user_account_id = current_user.id
|
||||
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
|
||||
@ -5,7 +5,7 @@ from uuid import UUID
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import String, cast, func, or_, select
|
||||
from sqlalchemy import String, case, cast, func, literal, or_, select
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
@ -169,9 +169,17 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
# Use database-specific methods for JSON array search
|
||||
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
|
||||
# PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text
|
||||
# Feed the set-returning function a JSON array in every row. Filtering in
|
||||
# the subquery is not enough because PostgreSQL can still evaluate the
|
||||
# SRF on scalar JSON before applying the predicate.
|
||||
keywords_jsonb = cast(DocumentSegment.keywords, JSONB)
|
||||
keywords_array = case(
|
||||
(func.jsonb_typeof(keywords_jsonb) == "array", keywords_jsonb),
|
||||
else_=cast(literal("[]"), JSONB),
|
||||
)
|
||||
keywords_condition = func.array_to_string(
|
||||
func.array(
|
||||
select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB)))
|
||||
select(func.jsonb_array_elements_text(keywords_array))
|
||||
.correlate(DocumentSegment)
|
||||
.scalar_subquery()
|
||||
),
|
||||
|
||||
@ -1,15 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from fields.hit_testing_fields import HitTestingResponse
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
|
||||
from .. import console_ns
|
||||
@ -20,86 +17,8 @@ from ..wraps import (
|
||||
setup_required,
|
||||
)
|
||||
|
||||
|
||||
class HitTestingDocument(ResponseModel):
|
||||
id: str | None = None
|
||||
data_source_type: str | None = None
|
||||
name: str | None = None
|
||||
doc_type: str | None = None
|
||||
doc_metadata: Any | None = None
|
||||
|
||||
|
||||
class HitTestingSegment(ResponseModel):
|
||||
id: str | None = None
|
||||
position: int | None = None
|
||||
document_id: str | None = None
|
||||
content: str | None = None
|
||||
sign_content: str | None = None
|
||||
answer: str | None = None
|
||||
word_count: int | None = None
|
||||
tokens: int | None = None
|
||||
keywords: list[str] = Field(default_factory=list)
|
||||
index_node_id: str | None = None
|
||||
index_node_hash: str | None = None
|
||||
hit_count: int | None = None
|
||||
enabled: bool | None = None
|
||||
disabled_at: int | None = None
|
||||
disabled_by: str | None = None
|
||||
status: str | None = None
|
||||
created_by: str | None = None
|
||||
created_at: int | None = None
|
||||
indexing_at: int | None = None
|
||||
completed_at: int | None = None
|
||||
error: str | None = None
|
||||
stopped_at: int | None = None
|
||||
document: HitTestingDocument | None = None
|
||||
|
||||
@field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class HitTestingChildChunk(ResponseModel):
|
||||
id: str | None = None
|
||||
content: str | None = None
|
||||
position: int | None = None
|
||||
score: float | None = None
|
||||
|
||||
|
||||
class HitTestingFile(ResponseModel):
|
||||
id: str | None = None
|
||||
name: str | None = None
|
||||
size: int | None = None
|
||||
extension: str | None = None
|
||||
mime_type: str | None = None
|
||||
source_url: str | None = None
|
||||
|
||||
|
||||
class HitTestingRecord(ResponseModel):
|
||||
segment: HitTestingSegment | None = None
|
||||
child_chunks: list[HitTestingChildChunk] = Field(default_factory=list)
|
||||
score: float | None = None
|
||||
tsne_position: Any | None = None
|
||||
files: list[HitTestingFile] = Field(default_factory=list)
|
||||
summary: str | None = None
|
||||
|
||||
|
||||
class HitTestingResponse(ResponseModel):
|
||||
query: str
|
||||
records: list[HitTestingRecord] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
HitTestingPayload,
|
||||
HitTestingDocument,
|
||||
HitTestingSegment,
|
||||
HitTestingChildChunk,
|
||||
HitTestingFile,
|
||||
HitTestingRecord,
|
||||
HitTestingResponse,
|
||||
)
|
||||
register_schema_models(console_ns, HitTestingPayload)
|
||||
register_response_schema_models(console_ns, HitTestingResponse)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
|
||||
@ -119,12 +38,11 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id: UUID):
|
||||
def post(self, dataset_id: UUID) -> dict[str, object]:
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||
payload = HitTestingPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
args = self.parse_args(console_ns.payload)
|
||||
self.hit_testing_args_check(args)
|
||||
|
||||
return HitTestingResponse.model_validate(self.perform_hit_testing(dataset, args)).model_dump(mode="json")
|
||||
return dump_response(HitTestingResponse, self.perform_hit_testing(dataset, args))
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
@ -19,10 +18,10 @@ from core.errors.error import (
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.hit_testing_service import HitTestingService
|
||||
@ -38,16 +37,6 @@ class HitTestingPayload(BaseModel):
|
||||
|
||||
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def _extract_hit_testing_query(query: Any) -> str:
|
||||
"""Return the query string from the service response shape."""
|
||||
if isinstance(query, dict):
|
||||
content = query.get("content")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
raise ValueError("Invalid hit testing query response")
|
||||
|
||||
@staticmethod
|
||||
def _prepare_hit_testing_records(records: Any) -> list[dict[str, Any]]:
|
||||
"""Ensure collection fields match the API schema before response validation."""
|
||||
@ -63,6 +52,7 @@ class DatasetsHitTestingBase:
|
||||
segment = normalized_record.get("segment")
|
||||
if isinstance(segment, dict):
|
||||
normalized_segment = dict(segment)
|
||||
normalized_segment.setdefault("sign_content", None)
|
||||
if normalized_segment.get("keywords") is None:
|
||||
normalized_segment["keywords"] = []
|
||||
normalized_record["segment"] = normalized_segment
|
||||
@ -73,12 +63,15 @@ class DatasetsHitTestingBase:
|
||||
if normalized_record.get("files") is None:
|
||||
normalized_record["files"] = []
|
||||
|
||||
normalized_record.setdefault("tsne_position", None)
|
||||
normalized_record.setdefault("summary", None)
|
||||
|
||||
normalized_records.append(normalized_record)
|
||||
|
||||
return normalized_records
|
||||
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str):
|
||||
def get_and_validate_dataset(dataset_id: str) -> Dataset:
|
||||
assert isinstance(current_user, Account)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
@ -92,33 +85,35 @@ class DatasetsHitTestingBase:
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def hit_testing_args_check(args: dict[str, Any]):
|
||||
def hit_testing_args_check(args: dict[str, Any]) -> None:
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
@staticmethod
|
||||
def parse_args(payload: dict[str, Any]) -> dict[str, Any]:
|
||||
def parse_args(payload: dict[str, Any] | None) -> dict[str, Any]:
|
||||
"""Validate and return hit-testing arguments from an incoming payload."""
|
||||
hit_testing_payload = HitTestingPayload.model_validate(payload or {})
|
||||
return hit_testing_payload.model_dump(exclude_none=True)
|
||||
|
||||
@staticmethod
|
||||
def perform_hit_testing(dataset, args):
|
||||
def perform_hit_testing(dataset: Dataset, args: dict[str, Any]) -> dict[str, Any]:
|
||||
assert isinstance(current_user, Account)
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args.get("query"),
|
||||
query=cast(str, args.get("query")),
|
||||
account=current_user,
|
||||
retrieval_model=args.get("retrieval_model"),
|
||||
external_retrieval_model=args.get("external_retrieval_model"),
|
||||
external_retrieval_model=cast(dict[str, Any], args.get("external_retrieval_model")),
|
||||
attachment_ids=args.get("attachment_ids"),
|
||||
limit=10,
|
||||
)
|
||||
query = response.get("query")
|
||||
if not isinstance(query, dict) or not isinstance(query.get("content"), str):
|
||||
raise ValueError("Invalid hit testing query response")
|
||||
|
||||
return {
|
||||
"query": DatasetsHitTestingBase._extract_hit_testing_query(response.get("query")),
|
||||
"records": DatasetsHitTestingBase._prepare_hit_testing_records(
|
||||
marshal(response.get("records", []), hit_testing_record_fields)
|
||||
),
|
||||
"query": {"content": query["content"]},
|
||||
"records": DatasetsHitTestingBase._prepare_hit_testing_records(response.get("records", [])),
|
||||
}
|
||||
except services.errors.index.IndexNotInitializedError:
|
||||
raise DatasetNotInitializedError()
|
||||
|
||||
@ -20,6 +20,7 @@ from controllers.console.app.error import (
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import InstalledApp
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
@ -40,8 +41,10 @@ register_schema_model(console_ns, TextToAudioPayload)
|
||||
endpoint="installed_app_audio",
|
||||
)
|
||||
class ChatAudioApi(InstalledAppResource):
|
||||
def post(self, installed_app):
|
||||
def post(self, installed_app: InstalledApp):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
file = request.files["file"]
|
||||
|
||||
@ -81,8 +84,10 @@ class ChatAudioApi(InstalledAppResource):
|
||||
)
|
||||
class ChatTextApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[TextToAudioPayload.__name__])
|
||||
def post(self, installed_app):
|
||||
def post(self, installed_app: InstalledApp):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
try:
|
||||
payload = TextToAudioPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
|
||||
@ -31,7 +31,7 @@ from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
from models.model import AppMode, InstalledApp
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
@ -83,8 +83,10 @@ register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
)
|
||||
class CompletionApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__])
|
||||
def post(self, installed_app):
|
||||
def post(self, installed_app: InstalledApp):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -133,8 +135,10 @@ class CompletionApi(InstalledAppResource):
|
||||
)
|
||||
class CompletionStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app, task_id: str):
|
||||
def post(self, installed_app: InstalledApp, task_id: str):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -157,8 +161,10 @@ class CompletionStopApi(InstalledAppResource):
|
||||
)
|
||||
class ChatApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
|
||||
def post(self, installed_app):
|
||||
def post(self, installed_app: InstalledApp):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -209,8 +215,10 @@ class ChatApi(InstalledAppResource):
|
||||
)
|
||||
class ChatStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app, task_id: str):
|
||||
def post(self, installed_app: InstalledApp, task_id: str):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
@ -8,6 +8,7 @@ from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console.app.error import AppUnavailableError
|
||||
from controllers.console.explore.error import NotChatAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -20,7 +21,7 @@ from fields.conversation_fields import (
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
from models.model import AppMode, InstalledApp
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
|
||||
from services.web_conversation_service import WebConversationService
|
||||
@ -44,8 +45,10 @@ register_response_schema_models(console_ns, ResultResponse)
|
||||
)
|
||||
class ConversationListApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[ConversationListQuery.__name__])
|
||||
def get(self, installed_app):
|
||||
def get(self, installed_app: InstalledApp):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -92,8 +95,10 @@ class ConversationListApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationApi(InstalledAppResource):
|
||||
@console_ns.response(204, "Conversation deleted successfully")
|
||||
def delete(self, installed_app, c_id: UUID):
|
||||
def delete(self, installed_app: InstalledApp, c_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -115,8 +120,10 @@ class ConversationApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationRenameApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
|
||||
def post(self, installed_app, c_id: UUID):
|
||||
def post(self, installed_app: InstalledApp, c_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -146,8 +153,10 @@ class ConversationRenameApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationPinApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
|
||||
def patch(self, installed_app, c_id: UUID):
|
||||
def patch(self, installed_app: InstalledApp, c_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -170,8 +179,10 @@ class ConversationPinApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationUnPinApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
|
||||
def patch(self, installed_app, c_id: UUID):
|
||||
def patch(self, installed_app: InstalledApp, c_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
@ -262,7 +262,7 @@ class InstalledAppApi(InstalledAppResource):
|
||||
"""
|
||||
|
||||
@console_ns.response(204, "App uninstalled successfully")
|
||||
def delete(self, installed_app):
|
||||
def delete(self, installed_app: InstalledApp):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
if installed_app.app_owner_tenant_id == current_tenant_id:
|
||||
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
||||
@ -273,7 +273,7 @@ class InstalledAppApi(InstalledAppResource):
|
||||
return "", 204
|
||||
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultMessageResponse.__name__])
|
||||
def patch(self, installed_app):
|
||||
def patch(self, installed_app: InstalledApp):
|
||||
payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
commit_args = False
|
||||
|
||||
@ -10,6 +10,7 @@ from controllers.common.controller_schemas import MessageFeedbackPayload, Messag
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console.app.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
@ -30,7 +31,7 @@ from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from models import Account
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import AppMode
|
||||
from models.model import AppMode, InstalledApp
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
@ -61,8 +62,10 @@ register_response_schema_models(console_ns, ResultResponse, SuggestedQuestionsRe
|
||||
class MessageListApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[MessageListQuery.__name__])
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, installed_app):
|
||||
def get(self, current_user: Account, installed_app: InstalledApp):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -98,8 +101,10 @@ class MessageFeedbackApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
|
||||
@console_ns.response(200, "Feedback submitted successfully", console_ns.models[ResultResponse.__name__])
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, installed_app, message_id: UUID):
|
||||
def post(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
message_id_str = str(message_id)
|
||||
|
||||
@ -126,8 +131,10 @@ class MessageFeedbackApi(InstalledAppResource):
|
||||
class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[MoreLikeThisQuery.__name__])
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, installed_app, message_id: UUID):
|
||||
def get(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -172,8 +179,10 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SuggestedQuestionsResponse.__name__])
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, installed_app, message_id: UUID):
|
||||
def get(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
@ -7,12 +7,14 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import AppUnavailableError
|
||||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import with_current_user
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from models import Account
|
||||
from models.model import InstalledApp
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
@ -24,8 +26,10 @@ register_response_schema_models(console_ns, ResultResponse)
|
||||
class SavedMessageListApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, installed_app):
|
||||
def get(self, current_user: Account, installed_app: InstalledApp):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -48,8 +52,10 @@ class SavedMessageListApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, installed_app):
|
||||
def post(self, current_user: Account, installed_app: InstalledApp):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -69,8 +75,10 @@ class SavedMessageListApi(InstalledAppResource):
|
||||
class SavedMessageApi(InstalledAppResource):
|
||||
@console_ns.response(204, "Saved message deleted successfully")
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, installed_app, message_id: UUID):
|
||||
def delete(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
message_id_str = str(message_id)
|
||||
|
||||
|
||||
@ -28,7 +28,7 @@ from controllers.console.wraps import (
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileResponse, UploadConfig
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models import Account
|
||||
from services.file_service import FileService
|
||||
|
||||
from . import console_ns
|
||||
|
||||
@ -18,7 +18,7 @@ from extensions.ext_database import db
|
||||
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models import Account
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
|
||||
@ -9,9 +9,16 @@ from werkzeug.exceptions import Forbidden
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.enums import TagType
|
||||
from services.tag_service import (
|
||||
SaveTagPayload,
|
||||
@ -92,8 +99,8 @@ class TagListApi(Resource):
|
||||
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
|
||||
)
|
||||
@console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])})
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
raw_args = request.args.to_dict()
|
||||
param = TagListQueryParam.model_validate(raw_args)
|
||||
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
|
||||
@ -109,9 +116,9 @@ class TagListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
# Allow users with edit permission, or dataset editors (including dataset operators).
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
@ -132,8 +139,8 @@ class TagUpdateDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, tag_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def patch(self, current_user: Account, tag_id: UUID):
|
||||
tag_id_str = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
@ -163,20 +170,19 @@ class TagUpdateDeleteApi(Resource):
|
||||
return "", 204
|
||||
|
||||
|
||||
def _require_tag_binding_edit_permission() -> None:
|
||||
def _require_tag_binding_edit_permission(current_user: Account) -> None:
|
||||
"""
|
||||
Ensure the current account can edit tag bindings.
|
||||
|
||||
Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
def _create_tag_bindings() -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission()
|
||||
def _create_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission(current_user)
|
||||
|
||||
payload = TagBindingPayload.model_validate(console_ns.payload or {})
|
||||
TagService.save_tag_binding(
|
||||
@ -189,8 +195,8 @@ def _create_tag_bindings() -> tuple[dict[str, str], int]:
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
def _remove_tag_bindings() -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission()
|
||||
def _remove_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission(current_user)
|
||||
|
||||
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
|
||||
TagService.delete_tag_binding(
|
||||
@ -213,8 +219,9 @@ class TagBindingCollectionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _create_tag_bindings()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
return _create_tag_bindings(current_user)
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/remove")
|
||||
@ -228,5 +235,6 @@ class TagBindingRemoveApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _remove_tag_bindings()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
return _remove_tag_bindings(current_user)
|
||||
|
||||
@ -77,7 +77,7 @@ register_response_schema_models(console_ns, SimpleResultDataResponse, Verificati
|
||||
def _is_role_enabled(role: TenantAccountRole | str, tenant_id: str) -> bool:
|
||||
if role != TenantAccountRole.DATASET_OPERATOR:
|
||||
return True
|
||||
return FeatureService.get_features(tenant_id=tenant_id).dataset_operator_enabled
|
||||
return FeatureService.get_features(tenant_id=tenant_id, exclude_vector_space=True).dataset_operator_enabled
|
||||
|
||||
|
||||
def _normalize_invitee_emails(emails: list[str]) -> list[str]:
|
||||
@ -113,7 +113,7 @@ def _check_member_invite_limits(tenant_id: str, new_member_count: int) -> None:
|
||||
if new_member_count <= 0:
|
||||
return
|
||||
|
||||
features = FeatureService.get_features(tenant_id=tenant_id)
|
||||
features = FeatureService.get_features(tenant_id=tenant_id, exclude_vector_space=True)
|
||||
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
workspace_members = features.workspace_members
|
||||
|
||||
@ -8,12 +8,17 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
@ -138,9 +143,8 @@ class DefaultModelApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
args = ParserGetDefault.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -156,9 +160,8 @@ class DefaultModelApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
args = ParserPostDefault.model_validate(console_ns.payload)
|
||||
model_provider_service = ModelProviderService()
|
||||
model_settings = args.model_settings
|
||||
@ -189,9 +192,8 @@ class ModelProviderModelApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider):
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
|
||||
|
||||
@ -202,9 +204,9 @@ class ModelProviderModelApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
# To save the model's load balance configs
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
args = ParserPostModels.model_validate(console_ns.payload)
|
||||
|
||||
if args.config_from == "custom-model":
|
||||
@ -249,9 +251,8 @@ class ModelProviderModelApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def delete(self, tenant_id: str, provider: str):
|
||||
args = ParserDeleteModels.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -268,9 +269,8 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -323,9 +323,8 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
args = ParserCreateCredential.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -355,8 +354,8 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def put(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def put(self, current_tenant_id: str, provider: str):
|
||||
args = ParserUpdateCredential.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -382,8 +381,8 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, provider: str):
|
||||
args = ParserDeleteCredential.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -406,8 +405,8 @@ class ModelProviderModelCredentialSwitchApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
args = ParserSwitch.model_validate(console_ns.payload)
|
||||
|
||||
service = ModelProviderService()
|
||||
@ -430,9 +429,8 @@ class ModelProviderModelEnableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def patch(self, tenant_id: str, provider: str):
|
||||
args = ParserDeleteModels.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -452,9 +450,8 @@ class ModelProviderModelDisableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def patch(self, tenant_id: str, provider: str):
|
||||
args = ParserDeleteModels.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -480,8 +477,8 @@ class ModelProviderModelValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
args = ParserValidate.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -515,9 +512,9 @@ class ModelProviderModelParameterRuleApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
args = ParserParameter.model_validate(request.args.to_dict(flat=True))
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
parameter_rules = model_provider_service.get_model_parameter_rules(
|
||||
@ -532,8 +529,8 @@ class ModelProviderAvailableModelApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, model_type: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, model_type: str):
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
|
||||
|
||||
|
||||
@ -166,10 +166,10 @@ class TenantListApi(Resource):
|
||||
if tenant_plan:
|
||||
plan = tenant_plan["plan"] or CloudPlan.SANDBOX
|
||||
else:
|
||||
features = FeatureService.get_features(tenant.id)
|
||||
features = FeatureService.get_features(tenant.id, exclude_vector_space=True)
|
||||
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
|
||||
elif not is_enterprise_only:
|
||||
features = FeatureService.get_features(tenant.id)
|
||||
features = FeatureService.get_features(tenant.id, exclude_vector_space=True)
|
||||
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
|
||||
|
||||
# Create a dictionary with tenant attributes
|
||||
|
||||
@ -96,21 +96,28 @@ def cloud_edition_billing_resource_check[**P, R](resource: str) -> Callable[[Cal
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if resource == "vector_space":
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
vector_space = FeatureService.get_vector_space(current_tenant_id)
|
||||
if 0 < vector_space.limit <= vector_space.size:
|
||||
abort(
|
||||
403,
|
||||
"The capacity of the knowledge storage space has reached the limit of your subscription.",
|
||||
)
|
||||
return view(*args, **kwargs)
|
||||
|
||||
features = FeatureService.get_features(current_tenant_id, exclude_vector_space=True)
|
||||
if features.billing.enabled:
|
||||
members = features.members
|
||||
apps = features.apps
|
||||
vector_space = features.vector_space
|
||||
documents_upload_quota = features.documents_upload_quota
|
||||
annotation_quota_limit = features.annotation_quota_limit
|
||||
if resource == "members" and 0 < members.limit <= members.size:
|
||||
abort(403, "The number of members has reached the limit of your subscription.")
|
||||
elif resource == "apps" and 0 < apps.limit <= apps.size:
|
||||
abort(403, "The number of apps has reached the limit of your subscription.")
|
||||
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
|
||||
abort(
|
||||
403, "The capacity of the knowledge storage space has reached the limit of your subscription."
|
||||
)
|
||||
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
||||
# The api of file upload is used in the multiple places,
|
||||
# so we need to check the source of the request from datasets
|
||||
@ -140,7 +147,7 @@ def cloud_edition_billing_knowledge_limit_check[**P, R](
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
features = FeatureService.get_features(current_tenant_id, exclude_vector_space=True)
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == CloudPlan.SANDBOX:
|
||||
@ -291,7 +298,7 @@ def knowledge_pipeline_publish_enabled[**P, R](view: Callable[P, R]) -> Callable
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
features = FeatureService.get_features(current_tenant_id, exclude_vector_space=True)
|
||||
if features.knowledge_pipeline.publish_enabled:
|
||||
return view(*args, **kwargs)
|
||||
abort(403)
|
||||
|
||||
@ -37,6 +37,13 @@ from controllers.openapi._models import (
|
||||
DeviceMutateRequest,
|
||||
DeviceMutateResponse,
|
||||
DevicePollRequest,
|
||||
MemberActionResponse,
|
||||
MemberInvitePayload,
|
||||
MemberInviteResponse,
|
||||
MemberListQuery,
|
||||
MemberListResponse,
|
||||
MemberResponse,
|
||||
MemberRoleUpdatePayload,
|
||||
MessageMetadata,
|
||||
PermittedExternalAppsListQuery,
|
||||
PermittedExternalAppsListResponse,
|
||||
@ -63,6 +70,9 @@ register_schema_models(
|
||||
DevicePollRequest,
|
||||
DeviceLookupQuery,
|
||||
DeviceMutateRequest,
|
||||
MemberInvitePayload,
|
||||
MemberListQuery,
|
||||
MemberRoleUpdatePayload,
|
||||
PermittedExternalAppsListQuery,
|
||||
)
|
||||
register_response_schema_models(
|
||||
@ -86,6 +96,10 @@ register_response_schema_models(
|
||||
WorkspaceSummaryResponse,
|
||||
WorkspaceListResponse,
|
||||
WorkspaceDetailResponse,
|
||||
MemberResponse,
|
||||
MemberListResponse,
|
||||
MemberInviteResponse,
|
||||
MemberActionResponse,
|
||||
DeviceCodeResponse,
|
||||
DeviceLookupResponse,
|
||||
DeviceMutateResponse,
|
||||
|
||||
@ -6,7 +6,7 @@ from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from libs.helper import UUIDStrOrEmpty, uuid_value
|
||||
from libs.helper import EmailStr, UUIDStrOrEmpty, uuid_value
|
||||
from models.model import AppMode
|
||||
|
||||
# Server-side cap on `limit` query param for /openapi/v1/* list endpoints.
|
||||
@ -342,3 +342,61 @@ class ApprovalGrantClaimsPayload(BaseModel):
|
||||
user_code: str = Field(min_length=1, max_length=32)
|
||||
nonce: str = Field(min_length=1, max_length=128)
|
||||
csrf_token: str = Field(min_length=1, max_length=128)
|
||||
|
||||
|
||||
# Closed enum for invite/update-role payloads. Owner is intentionally not
|
||||
# assignable through these endpoints — ownership transfer goes through the
|
||||
# console's three-step email-verification flow.
|
||||
MemberAssignableRole = Literal["normal", "admin"]
|
||||
|
||||
|
||||
class MemberResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
role: str
|
||||
status: str
|
||||
avatar: str | None = None
|
||||
|
||||
|
||||
class MemberListResponse(BaseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[MemberResponse]
|
||||
|
||||
|
||||
class MemberListQuery(BaseModel):
|
||||
"""Strict (extra='forbid')."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
page: int = Field(1, ge=1)
|
||||
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||
|
||||
|
||||
class MemberInvitePayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
email: EmailStr
|
||||
role: MemberAssignableRole
|
||||
|
||||
|
||||
class MemberRoleUpdatePayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
role: MemberAssignableRole
|
||||
|
||||
|
||||
class MemberInviteResponse(BaseModel):
|
||||
result: Literal["success"] = "success"
|
||||
email: str
|
||||
role: str
|
||||
member_id: str
|
||||
invite_url: str
|
||||
tenant_id: str
|
||||
|
||||
|
||||
class MemberActionResponse(BaseModel):
|
||||
result: Literal["success"] = "success"
|
||||
|
||||
@ -4,7 +4,7 @@ from datetime import UTC, datetime
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
@ -17,18 +17,17 @@ from controllers.openapi._models import (
|
||||
SessionRow,
|
||||
WorkspacePayload,
|
||||
)
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
AuthContext,
|
||||
SubjectType,
|
||||
Scope,
|
||||
TokenType,
|
||||
get_auth_ctx,
|
||||
validate_bearer,
|
||||
)
|
||||
from libs.rate_limit import (
|
||||
LIMIT_ME_PER_ACCOUNT,
|
||||
LIMIT_ME_PER_EMAIL,
|
||||
enforce,
|
||||
)
|
||||
from services.account_service import AccountService, TenantService
|
||||
@ -42,32 +41,18 @@ from services.oauth_device_flow import (
|
||||
@openapi_ns.route("/account")
|
||||
class AccountApi(Resource):
|
||||
@openapi_ns.response(200, "Account info", openapi_ns.models[AccountResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self):
|
||||
ctx = get_auth_ctx()
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{auth_data.account_id}")
|
||||
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}")
|
||||
else:
|
||||
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}")
|
||||
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
return AccountResponse(
|
||||
subject_type=ctx.subject_type,
|
||||
subject_email=ctx.subject_email,
|
||||
subject_issuer=ctx.subject_issuer,
|
||||
account=None,
|
||||
workspaces=[],
|
||||
default_workspace_id=None,
|
||||
).model_dump(mode="json")
|
||||
|
||||
account = AccountService.get_account_by_id(db.session, str(ctx.account_id)) if ctx.account_id else None
|
||||
memberships = TenantService.get_account_memberships(db.session, str(ctx.account_id)) if ctx.account_id else []
|
||||
account_id_str = str(auth_data.account_id) if auth_data.account_id else None
|
||||
account = AccountService.get_account_by_id(db.session, account_id_str) if account_id_str else None
|
||||
memberships = TenantService.get_account_memberships(db.session, account_id_str) if account_id_str else []
|
||||
default_ws_id = _pick_default_workspace(memberships)
|
||||
|
||||
return AccountResponse(
|
||||
subject_type=ctx.subject_type,
|
||||
subject_email=ctx.subject_email or (account.email if account else None),
|
||||
subject_type="account",
|
||||
subject_email=account.email if account else None,
|
||||
account=_account_payload(account) if account else None,
|
||||
workspaces=[_workspace_payload(m) for m in memberships],
|
||||
default_workspace_id=default_ws_id,
|
||||
@ -77,19 +62,17 @@ class AccountApi(Resource):
|
||||
@openapi_ns.route("/account/sessions/self")
|
||||
class AccountSessionsSelfApi(Resource):
|
||||
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def delete(self):
|
||||
ctx = get_auth_ctx()
|
||||
_require_oauth_subject(ctx)
|
||||
revoke_oauth_token(db.session, redis_client, str(ctx.token_id))
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def delete(self, *, auth_data: AuthData):
|
||||
revoke_oauth_token(db.session, redis_client, str(auth_data.token_id))
|
||||
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions")
|
||||
class AccountSessionsApi(Resource):
|
||||
@openapi_ns.response(200, "Session list", openapi_ns.models[SessionListResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self):
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
ctx = get_auth_ctx()
|
||||
now = datetime.now(UTC)
|
||||
page = int(request.args.get("page", "1"))
|
||||
@ -122,10 +105,9 @@ class AccountSessionsApi(Resource):
|
||||
@openapi_ns.route("/account/sessions/<string:session_id>")
|
||||
class AccountSessionByIdApi(Resource):
|
||||
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def delete(self, session_id: str):
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def delete(self, session_id: str, *, auth_data: AuthData):
|
||||
ctx = get_auth_ctx()
|
||||
_require_oauth_subject(ctx)
|
||||
|
||||
# 404 (not 403) on cross-subject so the endpoint doesn't leak
|
||||
# token IDs that belong to other subjects.
|
||||
@ -136,13 +118,6 @@ class AccountSessionByIdApi(Resource):
|
||||
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||
|
||||
|
||||
def _require_oauth_subject(ctx: AuthContext) -> None:
|
||||
if not ctx.source.startswith("oauth"):
|
||||
raise BadRequest(
|
||||
"this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs"
|
||||
)
|
||||
|
||||
|
||||
def _iso(dt: datetime | None) -> str | None:
|
||||
if dt is None:
|
||||
return None
|
||||
|
||||
@ -16,7 +16,8 @@ import services
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._audit import emit_app_run
|
||||
from controllers.openapi._models import AppRunRequest
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
@ -124,8 +125,9 @@ _DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = {
|
||||
class AppRunApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[AppRunRequest.__name__])
|
||||
@openapi_ns.response(200, "Run result (SSE stream)")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
payload = AppRunRequest.model_validate(body)
|
||||
@ -158,8 +160,9 @@ class AppRunApi(Resource):
|
||||
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/stop")
|
||||
class AppRunTaskStopApi(Resource):
|
||||
@openapi_ns.response(200, "Task stopped")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, task_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
return {"result": "success"}
|
||||
|
||||
@ -1,9 +1,4 @@
|
||||
"""GET /openapi/v1/apps and per-app reads.
|
||||
|
||||
Decorator order: `method_decorators` is innermost-first. `validate_bearer`
|
||||
is last → outermost → publishes the auth ContextVar before `require_scope`
|
||||
reads it.
|
||||
"""
|
||||
"""GET /openapi/v1/apps and per-app reads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -28,31 +23,17 @@ from controllers.openapi._models import (
|
||||
AppListRow,
|
||||
TagItem,
|
||||
)
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
AuthContext,
|
||||
Scope,
|
||||
SubjectType,
|
||||
get_auth_ctx,
|
||||
require_scope,
|
||||
require_workspace_member,
|
||||
validate_bearer,
|
||||
)
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppListParams, AppService
|
||||
from services.tag_service import TagService
|
||||
|
||||
_APPS_READ_DECORATORS = [
|
||||
require_scope(Scope.APPS_READ),
|
||||
accept_subjects(SubjectType.ACCOUNT),
|
||||
validate_bearer(accept=ACCEPT_USER_ANY),
|
||||
]
|
||||
|
||||
_ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"})
|
||||
|
||||
|
||||
@ -66,13 +47,9 @@ _EMPTY_PARAMETERS: dict[str, Any] = {
|
||||
|
||||
|
||||
class AppReadResource(Resource):
|
||||
"""Base for per-app read endpoints; subclasses call `_load()` for SSO/membership/exists checks."""
|
||||
|
||||
method_decorators = _APPS_READ_DECORATORS
|
||||
|
||||
def _load(self, app_id: str, workspace_id: str | None = None) -> tuple[App, AuthContext]:
|
||||
ctx: AuthContext = get_auth_ctx()
|
||||
"""Base for per-app read endpoints; subclasses call `_load()` for membership/exists checks."""
|
||||
|
||||
def _load(self, app_id: str, workspace_id: str | None = None) -> App:
|
||||
try:
|
||||
parsed_uuid = _uuid.UUID(app_id)
|
||||
is_uuid = True
|
||||
@ -99,8 +76,7 @@ class AppReadResource(Resource):
|
||||
raise Conflict("".join(lines))
|
||||
app = matches[0]
|
||||
|
||||
require_workspace_member(ctx, str(app.tenant_id))
|
||||
return app, ctx
|
||||
return app
|
||||
|
||||
|
||||
def parameters_payload(app: App) -> dict:
|
||||
@ -114,13 +90,14 @@ def parameters_payload(app: App) -> dict:
|
||||
class AppDescribeApi(AppReadResource):
|
||||
@openapi_ns.doc(params=query_params_from_model(AppDescribeQuery))
|
||||
@openapi_ns.response(200, "App description", openapi_ns.models[AppDescribeResponse.__name__])
|
||||
def get(self, app_id: str):
|
||||
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, app_id: str, *, auth_data: AuthData):
|
||||
try:
|
||||
query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
app, _ = self._load(app_id, workspace_id=query.workspace_id)
|
||||
app = self._load(app_id, workspace_id=query.workspace_id)
|
||||
|
||||
requested = query.fields
|
||||
want_info = requested is None or "info" in requested
|
||||
@ -168,20 +145,16 @@ class AppDescribeApi(AppReadResource):
|
||||
|
||||
@openapi_ns.route("/apps")
|
||||
class AppListApi(Resource):
|
||||
method_decorators = _APPS_READ_DECORATORS
|
||||
|
||||
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
|
||||
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
|
||||
def get(self):
|
||||
ctx: AuthContext = get_auth_ctx()
|
||||
|
||||
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
try:
|
||||
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
workspace_id = query.workspace_id
|
||||
require_workspace_member(ctx, workspace_id)
|
||||
|
||||
empty = (
|
||||
AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[]).model_dump(
|
||||
@ -237,7 +210,7 @@ class AppListApi(Resource):
|
||||
openapi_visible=True,
|
||||
)
|
||||
|
||||
pagination = AppService().get_paginate_apps(str(ctx.account_id), workspace_id, params)
|
||||
pagination = AppService().get_paginate_apps(str(auth_data.account_id), workspace_id, params)
|
||||
if pagination is None:
|
||||
return empty
|
||||
|
||||
|
||||
@ -18,37 +18,27 @@ from controllers.openapi._models import (
|
||||
PermittedExternalAppsListQuery,
|
||||
PermittedExternalAppsListResponse,
|
||||
)
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData, Edition
|
||||
from extensions.ext_database import db
|
||||
from libs.device_flow_security import enterprise_only
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
Scope,
|
||||
SubjectType,
|
||||
require_scope,
|
||||
validate_bearer,
|
||||
)
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.app_permitted_service import list_permitted_apps
|
||||
from services.openapi.license_gate import license_required
|
||||
|
||||
|
||||
@openapi_ns.route("/permitted-external-apps")
|
||||
class PermittedExternalAppsListApi(Resource):
|
||||
method_decorators = [
|
||||
require_scope(Scope.APPS_READ_PERMITTED_EXTERNAL),
|
||||
license_required,
|
||||
accept_subjects(SubjectType.EXTERNAL_SSO),
|
||||
validate_bearer(accept=ACCEPT_USER_ANY),
|
||||
enterprise_only,
|
||||
]
|
||||
|
||||
@openapi_ns.response(
|
||||
200, "Permitted external apps list", openapi_ns.models[PermittedExternalAppsListResponse.__name__]
|
||||
)
|
||||
def get(self):
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_READ_PERMITTED_EXTERNAL,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_EXTERNAL_SSO}),
|
||||
edition=frozenset({Edition.EE}),
|
||||
)
|
||||
def get(self, *, auth_data: AuthData):
|
||||
try:
|
||||
query = PermittedExternalAppsListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
|
||||
__all__ = ["OAUTH_BEARER_PIPELINE"]
|
||||
__all__ = ["auth_router"]
|
||||
|
||||
@ -1,46 +1,64 @@
|
||||
"""`OAUTH_BEARER_PIPELINE` — the auth scheme for openapi `/run` endpoints.
|
||||
|
||||
Endpoints attach via `@OAUTH_BEARER_PIPELINE.guard(scope=…)`. No alternative
|
||||
paths. Read endpoints (`/apps`, `/info`, `/parameters`, `/describe`) skip
|
||||
the pipeline and use `validate_bearer + require_scope + require_workspace_member`
|
||||
inline — they don't need `AppAuthzCheck`/`CallerMount`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
from controllers.openapi.auth.steps import (
|
||||
AppAuthzCheck,
|
||||
AppResolver,
|
||||
BearerCheck,
|
||||
CallerMount,
|
||||
ScopeCheck,
|
||||
SurfaceCheck,
|
||||
WorkspaceMembershipCheck,
|
||||
from controllers.openapi.auth.conditions import (
|
||||
EDITION_CE,
|
||||
EDITION_EE,
|
||||
LOADED_APP_IS_PRIVATE,
|
||||
PATH_HAS_APP_ID,
|
||||
WEBAPP_AUTH_ENABLED,
|
||||
)
|
||||
from controllers.openapi.auth.strategies import (
|
||||
AccountMounter,
|
||||
AclStrategy,
|
||||
AppAuthzStrategy,
|
||||
EndUserMounter,
|
||||
MembershipStrategy,
|
||||
from controllers.openapi.auth.data import Edition
|
||||
from controllers.openapi.auth.flow import When
|
||||
from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter
|
||||
from controllers.openapi.auth.prepare import (
|
||||
load_account,
|
||||
load_app,
|
||||
load_app_access_mode,
|
||||
load_tenant,
|
||||
resolve_external_user,
|
||||
)
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
def _resolve_app_authz_strategy() -> AppAuthzStrategy:
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
return AclStrategy()
|
||||
return MembershipStrategy()
|
||||
|
||||
|
||||
OAUTH_BEARER_PIPELINE = Pipeline(
|
||||
BearerCheck(),
|
||||
SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT})),
|
||||
ScopeCheck(),
|
||||
AppResolver(),
|
||||
WorkspaceMembershipCheck(),
|
||||
AppAuthzCheck(_resolve_app_authz_strategy),
|
||||
CallerMount(AccountMounter(), EndUserMounter()),
|
||||
from controllers.openapi.auth.verify import (
|
||||
check_acl,
|
||||
check_app_access,
|
||||
check_membership,
|
||||
check_private_app_permission,
|
||||
check_scope,
|
||||
)
|
||||
from libs.oauth_bearer import TokenType
|
||||
|
||||
account_pipeline = AuthPipeline(
|
||||
prepare=[
|
||||
When(PATH_HAS_APP_ID, then=load_app),
|
||||
When(PATH_HAS_APP_ID, then=load_tenant),
|
||||
load_account, # all tokens here are account tokens
|
||||
When(PATH_HAS_APP_ID & EDITION_EE, then=load_app_access_mode),
|
||||
],
|
||||
auth=[
|
||||
check_scope,
|
||||
When(EDITION_CE & PATH_HAS_APP_ID, then=check_membership),
|
||||
When(EDITION_EE & PATH_HAS_APP_ID & ~WEBAPP_AUTH_ENABLED, then=check_app_access),
|
||||
When(PATH_HAS_APP_ID & EDITION_EE & WEBAPP_AUTH_ENABLED, then=check_acl),
|
||||
When(EDITION_EE & LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
|
||||
],
|
||||
)
|
||||
|
||||
external_sso_pipeline = AuthPipeline(
|
||||
prepare=[
|
||||
When(PATH_HAS_APP_ID, then=load_app),
|
||||
When(PATH_HAS_APP_ID, then=load_tenant),
|
||||
When(PATH_HAS_APP_ID, then=resolve_external_user),
|
||||
When(PATH_HAS_APP_ID, then=load_app_access_mode),
|
||||
],
|
||||
auth=[
|
||||
check_scope,
|
||||
When(PATH_HAS_APP_ID & WEBAPP_AUTH_ENABLED, then=check_acl),
|
||||
When(LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
|
||||
],
|
||||
)
|
||||
|
||||
auth_router = PipelineRouter(
|
||||
{
|
||||
TokenType.OAUTH_ACCOUNT: PipelineRoute(account_pipeline),
|
||||
TokenType.OAUTH_EXTERNAL_SSO: PipelineRoute(external_sso_pipeline, required_edition=frozenset({Edition.EE})),
|
||||
}
|
||||
)
|
||||
|
||||
53
api/controllers/openapi/auth/conditions.py
Normal file
53
api/controllers/openapi/auth/conditions.py
Normal file
@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from controllers.openapi.auth.data import AuthData, Edition, RequestContext, current_edition
|
||||
from libs.oauth_bearer import TokenType
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
CondFn = Callable[[RequestContext, AuthData | None], bool]
|
||||
|
||||
|
||||
class Cond:
|
||||
def __init__(self, fn: CondFn) -> None:
|
||||
self._fn = fn
|
||||
|
||||
def __call__(self, ctx: RequestContext, data: AuthData | None = None) -> bool:
|
||||
return self._fn(ctx, data)
|
||||
|
||||
def __and__(self, other: Cond) -> Cond:
|
||||
return Cond(lambda ctx, data: self(ctx, data) and other(ctx, data))
|
||||
|
||||
def __or__(self, other: Cond) -> Cond:
|
||||
return Cond(lambda ctx, data: self(ctx, data) or other(ctx, data))
|
||||
|
||||
def __invert__(self) -> Cond:
|
||||
return Cond(lambda ctx, data: not self(ctx, data))
|
||||
|
||||
|
||||
def request_cond(fn: Callable[[RequestContext], bool]) -> Cond:
|
||||
return Cond(lambda ctx, _: fn(ctx))
|
||||
|
||||
|
||||
def data_cond(fn: Callable[[AuthData], bool]) -> Cond:
|
||||
return Cond(lambda _, data: data is not None and fn(data))
|
||||
|
||||
|
||||
def config_cond(fn: Callable[[], bool]) -> Cond:
|
||||
return Cond(lambda _, __: fn())
|
||||
|
||||
|
||||
TOKEN_IS_OAUTH_ACCOUNT = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_ACCOUNT)
|
||||
TOKEN_IS_OAUTH_EXTERNAL_SSO = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_EXTERNAL_SSO)
|
||||
|
||||
PATH_HAS_APP_ID = request_cond(lambda ctx: "app_id" in ctx.path_params)
|
||||
|
||||
EDITION_CE = config_cond(lambda: current_edition() == Edition.CE)
|
||||
EDITION_EE = config_cond(lambda: current_edition() == Edition.EE)
|
||||
EDITION_SAAS = config_cond(lambda: current_edition() == Edition.SAAS)
|
||||
|
||||
WEBAPP_AUTH_ENABLED = config_cond(lambda: FeatureService.get_system_features().webapp_auth.enabled)
|
||||
|
||||
LOADED_APP_IS_PRIVATE = data_cond(lambda data: data.app_access_mode == WebAppAccessMode.PRIVATE)
|
||||
@ -1,68 +0,0 @@
|
||||
"""Mutable per-request context for the openapi auth pipeline.
|
||||
|
||||
Every field starts None / empty and is filled in by a step. The pipeline
|
||||
is the only thing that should construct or mutate Context — handlers
|
||||
read populated values via the decorator's kwargs unpacking.
|
||||
|
||||
Context is intentionally decoupled from Flask's ``Request``: the pipeline
|
||||
guard extracts whatever transport-level inputs the steps need (bearer
|
||||
token, path params) at the boundary and writes them into Context fields,
|
||||
so steps stay testable without a request object and won't leak coupling
|
||||
to a specific framework.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from contextvars import Token
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Literal, Protocol
|
||||
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models import App, Tenant
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
required_scope: Scope
|
||||
bearer_token: str | None = None
|
||||
path_params: Mapping[str, str] = field(default_factory=dict)
|
||||
subject_type: SubjectType | None = None
|
||||
subject_email: str | None = None
|
||||
subject_issuer: str | None = None
|
||||
account_id: uuid.UUID | None = None
|
||||
scopes: frozenset[Scope] = field(default_factory=frozenset)
|
||||
token_id: uuid.UUID | None = None
|
||||
token_hash: str | None = None
|
||||
cached_verified_tenants: dict[str, bool] | None = None
|
||||
source: str | None = None
|
||||
expires_at: datetime | None = None
|
||||
app: App | None = None
|
||||
tenant: Tenant | None = None
|
||||
caller: object | None = None
|
||||
caller_kind: Literal["account", "end_user"] | None = None
|
||||
auth_ctx_reset_token: Token[AuthContext] | None = None
|
||||
|
||||
@property
|
||||
def must_tenant(self) -> Tenant:
|
||||
if not self.tenant:
|
||||
raise Unauthorized("tenant is not associated")
|
||||
return self.tenant
|
||||
|
||||
@property
|
||||
def must_subject_type(self) -> SubjectType:
|
||||
if not self.subject_type:
|
||||
raise Unauthorized("subject_type unset — BearerCheck did not run")
|
||||
return self.subject_type
|
||||
|
||||
|
||||
class Step(Protocol):
|
||||
"""One responsibility. Mutate ctx or raise to short-circuit."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None: ...
|
||||
69
api/controllers/openapi/auth/data.py
Normal file
69
api/controllers/openapi/auth/data.py
Normal file
@ -0,0 +1,69 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from configs import dify_config
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models.account import Account, Tenant
|
||||
from models.model import App, EndUser
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
|
||||
|
||||
class Edition(StrEnum):
|
||||
CE = "ce"
|
||||
EE = "ee"
|
||||
SAAS = "saas"
|
||||
|
||||
|
||||
def current_edition() -> Edition:
|
||||
if dify_config.EDITION == "CLOUD":
|
||||
return Edition.SAAS
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
return Edition.EE
|
||||
return Edition.CE
|
||||
|
||||
|
||||
class ExternalIdentity(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
email: str
|
||||
issuer: str | None = None
|
||||
|
||||
|
||||
class RequestContext(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
token_type: TokenType
|
||||
scope: Scope | None = None
|
||||
path_params: dict[str, str]
|
||||
|
||||
|
||||
class AuthData(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
required_scope: Scope | None = None
|
||||
token_type: TokenType
|
||||
account_id: uuid.UUID | None = None
|
||||
token_hash: str
|
||||
token_id: uuid.UUID | None = None
|
||||
scopes: frozenset[Scope]
|
||||
tenants: dict[str, bool] = Field(default_factory=dict)
|
||||
external_identity: ExternalIdentity | None = None
|
||||
path_params: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
app: App | None = None
|
||||
tenant: Tenant | None = None
|
||||
app_access_mode: WebAppAccessMode | None = None
|
||||
|
||||
caller: Account | EndUser | None = None
|
||||
caller_kind: Literal["account", "end_user"] | None = None
|
||||
|
||||
def require_app_context(self) -> tuple[App, Account | EndUser, Literal["account", "end_user"]]:
|
||||
if self.app is None or self.caller is None or self.caller_kind is None:
|
||||
raise InternalServerError("pipeline_invariant_violated: app context missing")
|
||||
return self.app, self.caller, self.caller_kind
|
||||
19
api/controllers/openapi/auth/flow.py
Normal file
19
api/controllers/openapi/auth/flow.py
Normal file
@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from controllers.openapi.auth.conditions import Cond
|
||||
from controllers.openapi.auth.data import AuthData, RequestContext
|
||||
|
||||
|
||||
class When:
|
||||
def __init__(self, condition: Cond, *, then: Callable[[Any], None]) -> None:
|
||||
self.condition = condition
|
||||
self._step = then
|
||||
|
||||
def applies(self, ctx: RequestContext, data: AuthData | None = None) -> bool:
|
||||
return self.condition(ctx, data)
|
||||
|
||||
def __call__(self, arg: Any) -> None:
|
||||
self._step(arg)
|
||||
@ -1,51 +1,209 @@
|
||||
"""Pipeline IS the auth scheme.
|
||||
"""Auth pipeline — entry point for all openapi auth.
|
||||
|
||||
`Pipeline.guard(scope=…)` is the only attachment point for endpoints —
|
||||
that is the design lock-in: forgetting an auth layer is structurally
|
||||
impossible because there is no "sometimes wrap, sometimes don't" choice.
|
||||
`PipelineRouter.guard()` is the only attachment point for endpoints.
|
||||
`AuthPipeline` is a pure step-runner with no routing concerns.
|
||||
`PipelineRoute` binds a pipeline to optional edition requirements.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.context import Context, Step
|
||||
from libs.oauth_bearer import Scope, extract_bearer, reset_auth_ctx
|
||||
from controllers.openapi._audit import emit_wrong_surface
|
||||
from controllers.openapi.auth.data import (
|
||||
AuthData,
|
||||
Edition,
|
||||
ExternalIdentity,
|
||||
RequestContext,
|
||||
current_edition,
|
||||
)
|
||||
from controllers.openapi.auth.flow import When
|
||||
from libs.oauth_bearer import (
|
||||
AuthContext,
|
||||
Scope,
|
||||
TokenType,
|
||||
extract_bearer,
|
||||
get_authenticator,
|
||||
reset_auth_ctx,
|
||||
set_auth_ctx,
|
||||
)
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
|
||||
|
||||
class Pipeline:
|
||||
def __init__(self, *steps: Step) -> None:
|
||||
self._steps = steps
|
||||
class AuthPipeline:
|
||||
"""Pure step-runner — no routing, no guard.
|
||||
|
||||
def run(self, ctx: Context) -> None:
|
||||
for step in self._steps:
|
||||
step(ctx)
|
||||
Both `prepare` and `auth` steps receive the same `AuthData` instance.
|
||||
`prepare` steps populate it; `auth` steps validate it.
|
||||
"""
|
||||
|
||||
def guard(self, *, scope: Scope):
|
||||
def decorator(view):
|
||||
def __init__(self, prepare: list, auth: list) -> None:
|
||||
self._prepare = prepare
|
||||
self._auth = auth
|
||||
|
||||
def _run(
|
||||
self,
|
||||
identity: AuthContext,
|
||||
args: tuple,
|
||||
kwargs: dict,
|
||||
view: Callable,
|
||||
*,
|
||||
scope: Scope | None,
|
||||
) -> Any:
|
||||
req_ctx = RequestContext(
|
||||
token_type=identity.token_type,
|
||||
scope=scope,
|
||||
path_params=dict(request.view_args or {}),
|
||||
)
|
||||
|
||||
data = AuthData(
|
||||
token_type=identity.token_type,
|
||||
account_id=identity.account_id,
|
||||
token_hash=identity.token_hash,
|
||||
token_id=identity.token_id,
|
||||
scopes=frozenset(identity.scopes),
|
||||
tenants=dict(identity.verified_tenants),
|
||||
required_scope=scope,
|
||||
path_params=dict(req_ctx.path_params),
|
||||
external_identity=(
|
||||
ExternalIdentity(email=identity.subject_email, issuer=identity.subject_issuer)
|
||||
if identity.subject_email
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
for step in self._prepare:
|
||||
if _should_run(step, req_ctx, data=None):
|
||||
step(data)
|
||||
|
||||
for step in self._auth:
|
||||
if _should_run(step, req_ctx, data=data):
|
||||
step(data)
|
||||
|
||||
reset_token = set_auth_ctx(identity)
|
||||
if data.caller:
|
||||
_mount_flask_login(data.caller)
|
||||
|
||||
try:
|
||||
kwargs["auth_data"] = data
|
||||
return view(*args, **kwargs)
|
||||
finally:
|
||||
reset_auth_ctx(reset_token)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineRoute:
|
||||
pipeline: AuthPipeline
|
||||
required_edition: frozenset[Edition] | None = None
|
||||
|
||||
|
||||
class PipelineRouter:
|
||||
"""Entry point for openapi auth.
|
||||
|
||||
`guard()` is the decorator that endpoints attach to. It applies
|
||||
global gates (edition, token type) then dispatches to the matching
|
||||
`PipelineRoute` for the token type.
|
||||
"""
|
||||
|
||||
def __init__(self, routes: dict[TokenType, PipelineRoute]) -> None:
|
||||
self._routes = routes
|
||||
|
||||
def guard(
|
||||
self,
|
||||
*,
|
||||
scope: Scope | None = None,
|
||||
allowed_token_types: frozenset[TokenType] | None = None,
|
||||
edition: frozenset[Edition] | None = None,
|
||||
) -> Callable:
|
||||
def decorator(view: Callable) -> Callable:
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
# Extract transport-level inputs at the boundary so steps
|
||||
# stay decoupled from Flask's request object.
|
||||
ctx = Context(
|
||||
required_scope=scope,
|
||||
bearer_token=extract_bearer(request),
|
||||
path_params=dict(request.view_args or {}),
|
||||
def decorated(*args: Any, **kwargs: Any) -> Any:
|
||||
return self._execute(
|
||||
args,
|
||||
kwargs,
|
||||
view,
|
||||
scope=scope,
|
||||
allowed_token_types=allowed_token_types,
|
||||
edition=edition,
|
||||
)
|
||||
try:
|
||||
self.run(ctx)
|
||||
kwargs.update(
|
||||
app_model=ctx.app,
|
||||
caller=ctx.caller,
|
||||
caller_kind=ctx.caller_kind,
|
||||
)
|
||||
return view(*args, **kwargs)
|
||||
finally:
|
||||
if ctx.auth_ctx_reset_token is not None:
|
||||
reset_auth_ctx(ctx.auth_ctx_reset_token)
|
||||
|
||||
return decorated
|
||||
|
||||
return decorator
|
||||
|
||||
def _execute(
|
||||
self,
|
||||
args: tuple,
|
||||
kwargs: dict,
|
||||
view: Callable,
|
||||
*,
|
||||
scope: Scope | None,
|
||||
allowed_token_types: frozenset[TokenType] | None,
|
||||
edition: frozenset[Edition] | None,
|
||||
) -> Any:
|
||||
# 404 not 403 — this edition doesn't expose the feature at all
|
||||
if edition is not None and current_edition() not in edition:
|
||||
raise NotFound()
|
||||
|
||||
license_checked = False
|
||||
if edition is not None and Edition.EE in edition:
|
||||
_check_license()
|
||||
license_checked = True
|
||||
|
||||
token = extract_bearer(request)
|
||||
if not token:
|
||||
raise Unauthorized("bearer required")
|
||||
|
||||
identity = get_authenticator().authenticate(token)
|
||||
|
||||
if allowed_token_types is not None and identity.token_type not in allowed_token_types:
|
||||
emit_wrong_surface(
|
||||
subject_type=_subject_type_str(identity),
|
||||
attempted_path=request.path,
|
||||
client_id=getattr(identity, "client_id", None),
|
||||
token_id=str(identity.token_id) if identity.token_id else None,
|
||||
)
|
||||
raise Forbidden("unsupported_token_type")
|
||||
|
||||
route = self._routes.get(identity.token_type)
|
||||
if route is None:
|
||||
raise Forbidden("unsupported_token_type")
|
||||
|
||||
if route.required_edition is not None:
|
||||
if current_edition() not in route.required_edition:
|
||||
raise Forbidden("external_sso_requires_ee")
|
||||
if not license_checked and Edition.EE in route.required_edition:
|
||||
_check_license()
|
||||
|
||||
return route.pipeline._run(identity, args, kwargs, view, scope=scope)
|
||||
|
||||
|
||||
def _should_run(step: Any, req_ctx: RequestContext, data: AuthData | None) -> bool:
|
||||
if isinstance(step, When):
|
||||
return step.applies(req_ctx, data)
|
||||
return True
|
||||
|
||||
|
||||
def _subject_type_str(identity: Any) -> str | None:
|
||||
subject = getattr(identity, "subject_type", None)
|
||||
if subject is None:
|
||||
return None
|
||||
return subject.value if hasattr(subject, "value") else str(subject)
|
||||
|
||||
|
||||
def _check_license() -> None:
|
||||
settings = FeatureService.get_system_features()
|
||||
if settings.license.status in {LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST}:
|
||||
raise Forbidden("license_invalid")
|
||||
|
||||
|
||||
def _mount_flask_login(user: Any) -> None:
|
||||
current_app.login_manager._update_request_context_with_user(user) # type: ignore[attr-defined]
|
||||
user_logged_in.send(current_app._get_current_object(), user=user) # type: ignore[attr-defined]
|
||||
|
||||
67
api/controllers/openapi/auth/prepare.py
Normal file
67
api/controllers/openapi/auth/prepare.py
Normal file
@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantStatus
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_service import AppService
|
||||
from services.end_user_service import EndUserService
|
||||
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode
|
||||
|
||||
|
||||
def load_app(data: AuthData) -> None:
|
||||
app_id = data.path_params["app_id"]
|
||||
app = AppService.get_app_by_id(db.session, app_id)
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
if not app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
data.app = app
|
||||
|
||||
|
||||
def load_tenant(data: AuthData) -> None:
|
||||
if data.app is None:
|
||||
raise InternalServerError("pipeline_invariant_violated: app not loaded before load_tenant")
|
||||
tenant = TenantService.get_tenant_by_id(db.session, str(data.app.tenant_id))
|
||||
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("workspace unavailable")
|
||||
data.tenant = tenant
|
||||
|
||||
|
||||
def load_account(data: AuthData) -> None:
|
||||
account = AccountService.get_account_by_id(db.session, str(data.account_id))
|
||||
if account is None:
|
||||
raise Unauthorized("account not found")
|
||||
if data.tenant:
|
||||
account.current_tenant = data.tenant
|
||||
data.caller = account
|
||||
data.caller_kind = "account"
|
||||
|
||||
|
||||
def resolve_external_user(data: AuthData) -> None:
|
||||
if data.tenant is None or data.app is None or data.external_identity is None:
|
||||
raise Unauthorized("missing context for external user resolution")
|
||||
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||
InvokeFrom.OPENAPI,
|
||||
tenant_id=str(data.tenant.id),
|
||||
app_id=str(data.app.id),
|
||||
user_id=data.external_identity.email,
|
||||
)
|
||||
data.caller = end_user
|
||||
data.caller_kind = "end_user"
|
||||
|
||||
|
||||
def load_app_access_mode(data: AuthData) -> None:
|
||||
if data.app is None:
|
||||
return
|
||||
try:
|
||||
settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(data.app.id))
|
||||
if settings is None:
|
||||
data.app_access_mode = None
|
||||
return
|
||||
data.app_access_mode = WebAppAccessMode(settings.access_mode)
|
||||
except ValueError:
|
||||
data.app_access_mode = None
|
||||
77
api/controllers/openapi/auth/role_gate.py
Normal file
77
api/controllers/openapi/auth/role_gate.py
Normal file
@ -0,0 +1,77 @@
|
||||
"""Workspace role gate.
|
||||
|
||||
Layered on top of `validate_bearer` + `accept_subjects(SubjectType.ACCOUNT)`
|
||||
for routes whose access depends on the caller's `TenantAccountJoin.role`
|
||||
in the workspace named by the `workspace_id` path parameter.
|
||||
|
||||
Usage::
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/members")
|
||||
class Members(Resource):
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
@require_workspace_role() # any member
|
||||
def get(self, workspace_id: str): ...
|
||||
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
def post(self, workspace_id: str): ...
|
||||
|
||||
Non-member callers get 404 (matching `GET /openapi/v1/workspaces/<id>`)
|
||||
so workspace IDs do not leak across tenants. A member without one of the
|
||||
allowed roles gets 403.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TypeVar
|
||||
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import try_get_auth_ctx
|
||||
from models.account import TenantAccountRole
|
||||
from services.account_service import TenantService
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., object])
|
||||
|
||||
|
||||
def require_workspace_role(*allowed_roles: TenantAccountRole) -> Callable[[F], F]:
|
||||
"""Gate a route on the caller's role in ``workspace_id``.
|
||||
|
||||
Pass no roles to require only membership. Pass one or more roles to
|
||||
require the caller's role be in that set.
|
||||
"""
|
||||
|
||||
allowed = frozenset(allowed_roles)
|
||||
|
||||
def deco(fn: F) -> F:
|
||||
@wraps(fn)
|
||||
def wrapper(*args: object, **kwargs: object) -> object:
|
||||
ctx = try_get_auth_ctx()
|
||||
if ctx is None or ctx.account_id is None:
|
||||
raise RuntimeError(
|
||||
"require_workspace_role called without account-bearer context; "
|
||||
"stack validate_bearer + accept_subjects(SubjectType.ACCOUNT) above it"
|
||||
)
|
||||
|
||||
workspace_id = kwargs.get("workspace_id")
|
||||
if not workspace_id:
|
||||
raise RuntimeError("require_workspace_role expects a 'workspace_id' route parameter")
|
||||
|
||||
role = TenantService.get_account_role_in_tenant(db.session, str(ctx.account_id), str(workspace_id))
|
||||
|
||||
if role is None:
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
if allowed and role not in allowed:
|
||||
raise Forbidden("insufficient workspace role")
|
||||
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
|
||||
return deco
|
||||
@ -1,170 +0,0 @@
|
||||
"""Pipeline steps. Each is one responsibility.
|
||||
|
||||
`BearerCheck` is the only step that touches the token registry; downstream
|
||||
steps see only the populated `Context`. `BearerCheck` also publishes the
|
||||
resolved identity to the openapi auth ``ContextVar`` (the same one the
|
||||
decorator-level :func:`libs.oauth_bearer.validate_bearer` writes to) so the
|
||||
surface gate and any handler reading the request-scoped context has a single
|
||||
source of truth across both auth-attach paths. The reset token is stashed
|
||||
on `ctx.auth_ctx_reset_token`; `Pipeline.guard` resets the ContextVar in
|
||||
its `finally` so worker-thread reuse can't leak identity across requests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter
|
||||
from controllers.openapi.auth.surface_gate import check_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
AuthContext,
|
||||
InvalidBearerError,
|
||||
Scope,
|
||||
SubjectType,
|
||||
check_workspace_membership,
|
||||
get_authenticator,
|
||||
set_auth_ctx,
|
||||
)
|
||||
from models import TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppService
|
||||
|
||||
|
||||
class BearerCheck:
|
||||
"""Resolve bearer → populate identity fields. Rate-limit is enforced
|
||||
inside `BearerAuthenticator.authenticate`, so no separate step here.
|
||||
Also publishes the resolved `AuthContext` via
|
||||
:func:`libs.oauth_bearer.set_auth_ctx` — same shape the decorator-level
|
||||
``validate_bearer`` writes — so the surface gate + downstream readers
|
||||
don't see two different identity sources. The reset token is parked on
|
||||
``ctx.auth_ctx_reset_token`` for `Pipeline.guard` to consume."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if not ctx.bearer_token:
|
||||
raise Unauthorized("bearer required")
|
||||
|
||||
try:
|
||||
authn = get_authenticator().authenticate(ctx.bearer_token)
|
||||
except InvalidBearerError as e:
|
||||
raise Unauthorized(str(e))
|
||||
|
||||
ctx.subject_type = authn.subject_type
|
||||
ctx.subject_email = authn.subject_email
|
||||
ctx.subject_issuer = authn.subject_issuer
|
||||
ctx.account_id = authn.account_id
|
||||
ctx.scopes = frozenset(authn.scopes)
|
||||
ctx.source = authn.source
|
||||
ctx.token_id = authn.token_id
|
||||
ctx.expires_at = authn.expires_at
|
||||
ctx.token_hash = authn.token_hash
|
||||
ctx.cached_verified_tenants = dict(authn.verified_tenants)
|
||||
ctx.auth_ctx_reset_token = set_auth_ctx(authn)
|
||||
|
||||
|
||||
class ScopeCheck:
|
||||
"""Verify ctx.scopes (already populated by BearerCheck) covers required."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if Scope.FULL in ctx.scopes or ctx.required_scope in ctx.scopes:
|
||||
return
|
||||
raise Forbidden("insufficient_scope")
|
||||
|
||||
|
||||
class SurfaceCheck:
|
||||
"""Reject the request if the resolved subject is not in `accepted`."""
|
||||
|
||||
def __init__(self, *, accepted: frozenset[SubjectType]) -> None:
|
||||
self._accepted = accepted
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
check_surface(self._accepted)
|
||||
|
||||
|
||||
class AppResolver:
|
||||
"""Read ``app_id`` from ``ctx.path_params``; populate ctx.app + ctx.tenant.
|
||||
|
||||
Every endpoint using the OAuth bearer pipeline must declare
|
||||
``<string:app_id>`` in its route — that is the design lock-in (no body /
|
||||
header coupling). ``Pipeline.guard`` lifts ``request.view_args`` into
|
||||
``ctx.path_params`` at the boundary so this step doesn't need to know
|
||||
about the request object.
|
||||
"""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
app_id = ctx.path_params.get("app_id")
|
||||
if not app_id:
|
||||
raise BadRequest("app_id is required in path")
|
||||
app = AppService.get_app_by_id(db.session, app_id)
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
if not app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
tenant = TenantService.get_tenant_by_id(db.session, str(app.tenant_id))
|
||||
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("workspace unavailable")
|
||||
ctx.app, ctx.tenant = app, tenant
|
||||
|
||||
|
||||
class WorkspaceMembershipCheck:
|
||||
"""Layer 0 — workspace membership gate.
|
||||
|
||||
CE-only (skipped when ENTERPRISE_ENABLED). Account-subject bearers
|
||||
(dfoa_) only — SSO subjects skip.
|
||||
"""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
return
|
||||
if ctx.subject_type != SubjectType.ACCOUNT:
|
||||
return
|
||||
if ctx.account_id is None or ctx.tenant is None:
|
||||
raise Unauthorized("account_id or tenant unset — BearerCheck or AppResolver did not run")
|
||||
if ctx.token_hash is None:
|
||||
raise Unauthorized("token_hash unset — BearerCheck did not run")
|
||||
|
||||
check_workspace_membership(
|
||||
account_id=ctx.account_id,
|
||||
tenant_id=ctx.must_tenant.id,
|
||||
token_hash=ctx.token_hash,
|
||||
cached_verdicts=ctx.cached_verified_tenants or {},
|
||||
)
|
||||
|
||||
|
||||
class AppAuthzCheck:
|
||||
def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None:
|
||||
self._resolve = resolve_strategy
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if not self._resolve().authorize(ctx):
|
||||
raise Forbidden("subject_no_app_access")
|
||||
|
||||
|
||||
class CallerMount:
|
||||
def __init__(self, *mounters: CallerMounter) -> None:
|
||||
self._mounters = mounters
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if ctx.subject_type is None:
|
||||
raise Unauthorized("subject_type unset — BearerCheck did not run")
|
||||
for m in self._mounters:
|
||||
if m.applies_to(ctx.must_subject_type):
|
||||
m.mount(ctx)
|
||||
return
|
||||
raise Unauthorized("no caller mounter for subject type")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AppAuthzCheck",
|
||||
"AppResolver",
|
||||
"AuthContext",
|
||||
"BearerCheck",
|
||||
"CallerMount",
|
||||
"ScopeCheck",
|
||||
"SurfaceCheck",
|
||||
"WorkspaceMembershipCheck",
|
||||
]
|
||||
@ -1,168 +0,0 @@
|
||||
"""Strategy classes for the openapi auth pipeline.
|
||||
|
||||
App authorization (Acl/Membership) and caller mounting (Account/EndUser)
|
||||
vary along independent axes; each strategy is one class so the pipeline
|
||||
composition stays a flat list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import user_logged_in
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.end_user_service import EndUserService
|
||||
from services.enterprise.enterprise_service import (
|
||||
EnterpriseService,
|
||||
WebAppAccessMode,
|
||||
)
|
||||
|
||||
|
||||
class AppAuthzStrategy(Protocol):
|
||||
def authorize(self, ctx: Context) -> bool: ...
|
||||
|
||||
|
||||
class AclStrategy:
|
||||
"""Per-app ACL, evaluated in two stages.
|
||||
|
||||
The EE gateway has already enforced tenancy and workspace membership
|
||||
by the time this strategy runs, so AclStrategy only owns per-app ACL:
|
||||
|
||||
1. Subject vs access-mode compatibility (pure rule table). External-SSO
|
||||
bearers belong to public-facing apps only; account bearers cover the
|
||||
full set. A mismatch is an immediate deny — no IO.
|
||||
2. For modes that pair with the subject, decide whether the inner
|
||||
permission API must run. Only `PRIVATE` (per-app selected-user list)
|
||||
requires it; the remaining modes are pass-through.
|
||||
"""
|
||||
|
||||
_ALLOWED_MODES_BY_SUBJECT: dict[SubjectType, frozenset[WebAppAccessMode]] = {
|
||||
SubjectType.ACCOUNT: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
WebAppAccessMode.PRIVATE_ALL,
|
||||
WebAppAccessMode.PRIVATE,
|
||||
}
|
||||
),
|
||||
SubjectType.EXTERNAL_SSO: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
_MODES_REQUIRING_INNER_CHECK: frozenset[WebAppAccessMode] = frozenset({WebAppAccessMode.PRIVATE})
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.app is None:
|
||||
return False
|
||||
access_mode = self._fetch_access_mode(ctx.app.id)
|
||||
if access_mode is None:
|
||||
return False
|
||||
if not self._subject_allowed_for_mode(ctx.must_subject_type, access_mode):
|
||||
return False
|
||||
if access_mode not in self._MODES_REQUIRING_INNER_CHECK:
|
||||
return True
|
||||
return self._inner_permission_check(ctx)
|
||||
|
||||
@staticmethod
|
||||
def _fetch_access_mode(app_id: str) -> WebAppAccessMode | None:
|
||||
settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
|
||||
if settings is None:
|
||||
return None
|
||||
try:
|
||||
return WebAppAccessMode(settings.access_mode)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _subject_allowed_for_mode(cls, subject_type: SubjectType, access_mode: WebAppAccessMode) -> bool:
|
||||
return access_mode in cls._ALLOWED_MODES_BY_SUBJECT.get(subject_type, frozenset())
|
||||
|
||||
def _inner_permission_check(self, ctx: Context) -> bool:
|
||||
if ctx.app is None:
|
||||
return False
|
||||
user_id = self._resolve_user_id(ctx)
|
||||
if user_id is None:
|
||||
return False
|
||||
return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
user_id=user_id,
|
||||
app_id=ctx.app.id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_user_id(ctx: Context) -> str | None:
|
||||
if ctx.subject_type == SubjectType.ACCOUNT:
|
||||
return str(ctx.account_id) if ctx.account_id is not None else None
|
||||
if ctx.subject_email is None:
|
||||
return None
|
||||
account = AccountService.get_account_by_email(db.session, ctx.subject_email)
|
||||
return str(account.id) if account is not None else None
|
||||
|
||||
|
||||
class MembershipStrategy:
|
||||
"""Tenant-membership fallback.
|
||||
|
||||
Used when webapp-auth is disabled (CE deployment). Account-bearing
|
||||
subjects pass if they have a TenantAccountJoin row; EXTERNAL_SSO is
|
||||
denied (it requires the webapp-auth surface).
|
||||
"""
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
return False
|
||||
if ctx.tenant is None:
|
||||
return False
|
||||
return TenantService.account_belongs_to_tenant(db.session, ctx.account_id, ctx.tenant.id)
|
||||
|
||||
|
||||
def _login_as(user) -> None:
|
||||
"""Set Flask-Login request user so downstream services see the caller."""
|
||||
current_app.login_manager._update_request_context_with_user(user) # type:ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=user) # type:ignore
|
||||
|
||||
|
||||
class CallerMounter(Protocol):
|
||||
def applies_to(self, subject_type: SubjectType) -> bool: ...
|
||||
|
||||
def mount(self, ctx: Context) -> None: ...
|
||||
|
||||
|
||||
class AccountMounter:
|
||||
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||
return subject_type == SubjectType.ACCOUNT
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.account_id is None:
|
||||
raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run")
|
||||
account = AccountService.get_account_by_id(db.session, str(ctx.account_id))
|
||||
if account is None:
|
||||
raise RuntimeError("AccountMounter: account row missing for resolved bearer")
|
||||
account.current_tenant = ctx.must_tenant
|
||||
_login_as(account)
|
||||
ctx.caller, ctx.caller_kind = account, "account"
|
||||
|
||||
|
||||
class EndUserMounter:
|
||||
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||
return subject_type == SubjectType.EXTERNAL_SSO
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.tenant is None or ctx.app is None or ctx.subject_email is None:
|
||||
raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run")
|
||||
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||
InvokeFrom.OPENAPI,
|
||||
tenant_id=ctx.tenant.id,
|
||||
app_id=ctx.app.id,
|
||||
user_id=ctx.subject_email,
|
||||
)
|
||||
_login_as(end_user)
|
||||
ctx.caller, ctx.caller_kind = end_user, "end_user"
|
||||
82
api/controllers/openapi/auth/verify.py
Normal file
82
api/controllers/openapi/auth/verify.py
Normal file
@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from werkzeug.exceptions import Forbidden, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType, check_workspace_membership
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode
|
||||
|
||||
|
||||
def check_scope(data: AuthData) -> None:
|
||||
if data.required_scope is None:
|
||||
return
|
||||
if Scope.FULL in data.scopes or data.required_scope in data.scopes:
|
||||
return
|
||||
raise Forbidden("insufficient_scope")
|
||||
|
||||
|
||||
def check_membership(data: AuthData) -> None:
|
||||
if data.tenant is None:
|
||||
raise Unauthorized("tenant unset")
|
||||
if data.account_id is None:
|
||||
raise Unauthorized("account_id unset")
|
||||
check_workspace_membership(
|
||||
account_id=data.account_id,
|
||||
tenant_id=data.tenant.id,
|
||||
token_hash=data.token_hash,
|
||||
membership_cache=data.tenants,
|
||||
)
|
||||
|
||||
|
||||
def check_app_access(data: AuthData) -> None:
|
||||
if data.tenant is None:
|
||||
return
|
||||
if not TenantService.account_belongs_to_tenant(db.session, data.account_id, data.tenant.id):
|
||||
raise Forbidden("subject_no_app_access")
|
||||
|
||||
|
||||
_ALLOWED_MODES_BY_TOKEN_TYPE: dict[TokenType, frozenset[WebAppAccessMode]] = {
|
||||
TokenType.OAUTH_ACCOUNT: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
WebAppAccessMode.PRIVATE_ALL,
|
||||
WebAppAccessMode.PRIVATE,
|
||||
}
|
||||
),
|
||||
TokenType.OAUTH_EXTERNAL_SSO: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def check_acl(data: AuthData) -> None:
|
||||
if data.app is None or data.app_access_mode is None:
|
||||
raise Forbidden("app or access mode not loaded")
|
||||
allowed_modes = _ALLOWED_MODES_BY_TOKEN_TYPE.get(data.token_type, frozenset())
|
||||
if data.app_access_mode not in allowed_modes:
|
||||
raise Forbidden("subject_not_allowed_for_access_mode")
|
||||
|
||||
|
||||
def check_private_app_permission(data: AuthData) -> None:
|
||||
if data.app is None:
|
||||
raise Forbidden("app not loaded")
|
||||
user_id = _resolve_user_id(data)
|
||||
if user_id is None:
|
||||
raise Forbidden("cannot resolve user for private app check")
|
||||
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id=user_id, app_id=data.app.id):
|
||||
raise Forbidden("user_not_allowed_for_private_app")
|
||||
|
||||
|
||||
def _resolve_user_id(data: AuthData) -> str | None:
|
||||
if data.token_type == TokenType.OAUTH_ACCOUNT:
|
||||
return str(data.account_id) if data.account_id is not None else None
|
||||
if data.external_identity is None:
|
||||
return None
|
||||
account = AccountService.get_account_by_email(db.session, data.external_identity.email)
|
||||
return str(account.id) if account is not None else None
|
||||
@ -17,11 +17,11 @@ from controllers.common.errors import (
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileResponse
|
||||
from libs.oauth_bearer import Scope
|
||||
from models import Account, App
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@ -39,8 +39,9 @@ class AppFileUploadApi(Resource):
|
||||
}
|
||||
)
|
||||
@openapi_ns.response(HTTPStatus.CREATED, "File uploaded", openapi_ns.models[FileResponse.__name__])
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, app_model: App, caller: Account, caller_kind: str):
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, _ = auth_data.require_app_context()
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
if len(request.files) > 1:
|
||||
|
||||
@ -17,7 +17,8 @@ from werkzeug.exceptions import BadRequest, NotFound
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import to_timestamp
|
||||
@ -55,8 +56,9 @@ def _ensure_form_is_allowed_for_openapi(form) -> None:
|
||||
@openapi_ns.route("/apps/<string:app_id>/form/human_input/<string:form_token>")
|
||||
class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||
@openapi_ns.response(200, "Form definition")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def get(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def get(self, app_id: str, form_token: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
@ -69,8 +71,9 @@ class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
@openapi_ns.response(200, "Form submitted")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, form_token: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
payload = HumanInputFormSubmitPayload.model_validate(request.get_json(silent=True) or {})
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
|
||||
@ -17,7 +17,8 @@ from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
@ -28,7 +29,7 @@ from core.workflow.human_input_policy import HumanInputSurface
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, AppMode
|
||||
from models.model import AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
|
||||
@ -36,8 +37,9 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/events")
|
||||
class OpenApiWorkflowEventsApi(Resource):
|
||||
@openapi_ns.response(200, "SSE event stream")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def get(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def get(self, app_id: str, task_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
|
||||
raise UnprocessableEntity("mode_not_supported_for_event_reconnect")
|
||||
|
||||
@ -1,41 +1,129 @@
|
||||
"""User-scoped workspace reads under /openapi/v1/workspaces. Bearer-authed
|
||||
counterparts to the cookie-authed /console/api/workspaces endpoints.
|
||||
"""User-scoped workspace reads and member management under /openapi/v1/workspaces.
|
||||
|
||||
Account bearers (dfoa_) see every tenant they're a member of. External
|
||||
SSO bearers (dfoe_) have no account_id and so see an empty list — that
|
||||
matches /openapi/v1/account.
|
||||
Bearer-authed counterparts to the cookie-authed /console/api/workspaces
|
||||
endpoints. Account bearers (dfoa_) see every tenant they're a member of.
|
||||
External SSO bearers (dfoe_) have no account_id and so see an empty list —
|
||||
that matches /openapi/v1/account.
|
||||
|
||||
Member-management endpoints are gated by both `accept_subjects` (SSO out)
|
||||
and `require_workspace_role` (membership / role lookup against the path's
|
||||
``workspace_id``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from itertools import starmap
|
||||
from urllib import parse
|
||||
|
||||
from flask import jsonify, make_response, request
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import WorkspaceDetailResponse, WorkspaceListResponse, WorkspaceSummaryResponse
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
SubjectType,
|
||||
get_auth_ctx,
|
||||
validate_bearer,
|
||||
from controllers.openapi._models import (
|
||||
MemberActionResponse,
|
||||
MemberInvitePayload,
|
||||
MemberInviteResponse,
|
||||
MemberListQuery,
|
||||
MemberListResponse,
|
||||
MemberResponse,
|
||||
MemberRoleUpdatePayload,
|
||||
WorkspaceDetailResponse,
|
||||
WorkspaceListResponse,
|
||||
WorkspaceSummaryResponse,
|
||||
)
|
||||
from models import Tenant, TenantAccountJoin
|
||||
from services.account_service import TenantService
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.auth.role_gate import require_workspace_role
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models import Account, Tenant, TenantAccountJoin
|
||||
from models.account import TenantAccountRole, TenantStatus
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.errors.account import (
|
||||
AccountAlreadyInTenantError,
|
||||
AccountNotLinkTenantError,
|
||||
AccountRegisterError,
|
||||
CannotOperateSelfError,
|
||||
MemberNotInTenantError,
|
||||
NoPermissionError,
|
||||
RoleAlreadyAssignedError,
|
||||
)
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
def _validate_body[M: BaseModel](model: type[M]) -> M:
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
return model.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
|
||||
def _member_response(account: Account) -> MemberResponse:
|
||||
return MemberResponse(
|
||||
id=str(account.id),
|
||||
name=account.name,
|
||||
email=account.email,
|
||||
role=account.role.value if account.role else "",
|
||||
status=account.status.value if account.status else "",
|
||||
avatar=account.avatar,
|
||||
)
|
||||
|
||||
|
||||
def _load_tenant(workspace_id: str) -> Tenant:
|
||||
tenant = TenantService.get_tenant_by_id(db.session, workspace_id)
|
||||
if tenant is None or tenant.status != TenantStatus.NORMAL:
|
||||
raise NotFound("workspace not found")
|
||||
return tenant
|
||||
|
||||
|
||||
def _load_account(account_id: object) -> Account:
|
||||
account = AccountService.get_account_by_id(db.session, str(account_id)) if account_id else None
|
||||
if account is None:
|
||||
raise RuntimeError("authenticated account_id has no Account row")
|
||||
return account
|
||||
|
||||
|
||||
def _quota_error(*, code: str, message: str, hint: str) -> Forbidden:
|
||||
err = Forbidden(message)
|
||||
err.response = make_response(
|
||||
jsonify({"code": code, "message": message, "hint": hint}),
|
||||
403,
|
||||
)
|
||||
return err
|
||||
|
||||
|
||||
def _check_member_invite_quota(tenant_id: str) -> None:
|
||||
features = FeatureService.get_features(tenant_id)
|
||||
|
||||
if features.billing.enabled:
|
||||
members = features.members
|
||||
if 0 < members.limit <= members.size:
|
||||
raise _quota_error(
|
||||
code="members.limit_exceeded",
|
||||
message="Subscription member limit reached.",
|
||||
hint="Upgrade your plan to invite more members or remove an existing member first.",
|
||||
)
|
||||
|
||||
if features.workspace_members.enabled:
|
||||
if not features.workspace_members.is_available(1):
|
||||
raise _quota_error(
|
||||
code="workspace_members.license_exceeded",
|
||||
message="Workspace member license capacity reached.",
|
||||
hint="Contact your workspace administrator to expand the license seat count.",
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces")
|
||||
class WorkspacesApi(Resource):
|
||||
@openapi_ns.response(200, "Workspace list", openapi_ns.models[WorkspaceListResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
def get(self):
|
||||
ctx = get_auth_ctx()
|
||||
|
||||
rows = TenantService.get_workspaces_for_account(db.session, str(ctx.account_id))
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
rows = TenantService.get_workspaces_for_account(db.session, str(auth_data.account_id))
|
||||
|
||||
return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows))).model_dump(mode="json"), 200
|
||||
|
||||
@ -43,12 +131,9 @@ class WorkspacesApi(Resource):
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>")
|
||||
class WorkspaceByIdApi(Resource):
|
||||
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
def get(self, workspace_id: str):
|
||||
ctx = get_auth_ctx()
|
||||
|
||||
row = TenantService.find_workspace_for_account(db.session, str(ctx.account_id), workspace_id)
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, workspace_id: str, *, auth_data: AuthData):
|
||||
row = TenantService.find_workspace_for_account(db.session, str(auth_data.account_id), workspace_id)
|
||||
# 404 (not 403) on non-member so workspace IDs don't leak across tenants.
|
||||
if row is None:
|
||||
raise NotFound("workspace not found")
|
||||
@ -57,6 +142,172 @@ class WorkspaceByIdApi(Resource):
|
||||
return _workspace_detail(tenant, membership).model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/switch")
|
||||
class WorkspaceSwitchApi(Resource):
|
||||
"""Server-side switch — equivalent to the console's POST /workspaces/switch.
|
||||
|
||||
CLI `difyctl use workspace <id>` calls this; it does NOT mutate
|
||||
``hosts.yml`` on its own. Failure here must abort the local write so
|
||||
that ``hosts.yml`` never diverges from the server's ``current`` state.
|
||||
"""
|
||||
|
||||
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role()
|
||||
def post(self, workspace_id: str, *, auth_data: AuthData):
|
||||
account = _load_account(auth_data.account_id)
|
||||
|
||||
try:
|
||||
TenantService.switch_tenant(account, workspace_id)
|
||||
except AccountNotLinkTenantError:
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
row = TenantService.find_workspace_for_account(db.session, str(auth_data.account_id), workspace_id)
|
||||
if row is None:
|
||||
raise NotFound("workspace not found")
|
||||
tenant, membership = row
|
||||
return _workspace_detail(tenant, membership).model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/members")
|
||||
class WorkspaceMembersApi(Resource):
|
||||
"""List + invite members.
|
||||
|
||||
GET is any-member. POST requires admin/owner — owner can never be
|
||||
assigned through invite (ownership transfer is console-only).
|
||||
"""
|
||||
|
||||
@openapi_ns.doc(params=query_params_from_model(MemberListQuery))
|
||||
@openapi_ns.response(200, "Member list", openapi_ns.models[MemberListResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role()
|
||||
def get(self, workspace_id: str, *, auth_data: AuthData):
|
||||
try:
|
||||
query = MemberListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
tenant = _load_tenant(workspace_id)
|
||||
members = TenantService.get_tenant_members(tenant)
|
||||
total = len(members)
|
||||
start = (query.page - 1) * query.limit
|
||||
page_items = members[start : start + query.limit]
|
||||
return MemberListResponse(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
total=total,
|
||||
has_more=query.page * query.limit < total,
|
||||
data=[_member_response(m) for m in page_items],
|
||||
).model_dump(mode="json"), 200
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[MemberInvitePayload.__name__])
|
||||
@openapi_ns.response(201, "Member invited", openapi_ns.models[MemberInviteResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
def post(self, workspace_id: str, *, auth_data: AuthData):
|
||||
payload = _validate_body(MemberInvitePayload)
|
||||
inviter = _load_account(auth_data.account_id)
|
||||
tenant = _load_tenant(workspace_id)
|
||||
|
||||
_check_member_invite_quota(str(tenant.id))
|
||||
|
||||
try:
|
||||
token = RegisterService.invite_new_member(
|
||||
tenant=tenant,
|
||||
email=payload.email,
|
||||
language=None,
|
||||
role=payload.role,
|
||||
inviter=inviter,
|
||||
)
|
||||
except AccountAlreadyInTenantError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
except NoPermissionError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
except AccountRegisterError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
normalized_email = payload.email.lower()
|
||||
member = AccountService.get_account_by_email_with_case_fallback(normalized_email)
|
||||
if member is None:
|
||||
# invite_new_member just created or fetched this account.
|
||||
raise RuntimeError("invited member missing from DB after invite")
|
||||
|
||||
encoded_email = parse.quote(normalized_email)
|
||||
invite_url = f"{dify_config.CONSOLE_WEB_URL}/activate?email={encoded_email}&token={token}"
|
||||
return MemberInviteResponse(
|
||||
email=normalized_email,
|
||||
role=payload.role,
|
||||
member_id=str(member.id),
|
||||
invite_url=invite_url,
|
||||
tenant_id=str(tenant.id),
|
||||
).model_dump(mode="json"), 201
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/members/<string:member_id>")
|
||||
class WorkspaceMemberApi(Resource):
|
||||
"""Remove a member.
|
||||
|
||||
Self-removal and owner-removal are explicitly rejected by the service
|
||||
layer (CannotOperateSelfError, NoPermissionError) — both surface as
|
||||
400 per the spec, with the service's message preserved.
|
||||
"""
|
||||
|
||||
@openapi_ns.response(200, "Member removed", openapi_ns.models[MemberActionResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
def delete(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
|
||||
operator = _load_account(auth_data.account_id)
|
||||
tenant = _load_tenant(workspace_id)
|
||||
member = AccountService.get_account_by_id(db.session, member_id)
|
||||
if member is None:
|
||||
raise NotFound("member not found")
|
||||
|
||||
try:
|
||||
TenantService.remove_member_from_tenant(tenant, member, operator)
|
||||
except CannotOperateSelfError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
except NoPermissionError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
except MemberNotInTenantError as exc:
|
||||
raise NotFound(str(exc))
|
||||
|
||||
return MemberActionResponse().model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/members/<string:member_id>/role")
|
||||
class WorkspaceMemberRoleApi(Resource):
|
||||
"""Change a member's role.
|
||||
|
||||
Owner cannot be assigned here (closed enum). Admin cannot demote the
|
||||
standing owner (service NoPermissionError → 400, per spec).
|
||||
"""
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[MemberRoleUpdatePayload.__name__])
|
||||
@openapi_ns.response(200, "Role updated", openapi_ns.models[MemberActionResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
def put(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
|
||||
payload = _validate_body(MemberRoleUpdatePayload)
|
||||
operator = _load_account(auth_data.account_id)
|
||||
tenant = _load_tenant(workspace_id)
|
||||
member = AccountService.get_account_by_id(db.session, member_id)
|
||||
if member is None:
|
||||
raise NotFound("member not found")
|
||||
|
||||
try:
|
||||
TenantService.update_member_role(tenant, member, payload.role, operator)
|
||||
except CannotOperateSelfError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
except NoPermissionError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
except MemberNotInTenantError as exc:
|
||||
raise NotFound(str(exc))
|
||||
except RoleAlreadyAssignedError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
return MemberActionResponse().model_dump(mode="json"), 200
|
||||
|
||||
|
||||
def _workspace_summary(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceSummaryResponse:
|
||||
return WorkspaceSummaryResponse(
|
||||
id=str(tenant.id),
|
||||
|
||||
@ -1,11 +1,14 @@
|
||||
from uuid import UUID
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
|
||||
from fields.hit_testing_fields import HitTestingResponse
|
||||
from libs.helper import dump_response
|
||||
|
||||
register_schema_model(service_api_ns, HitTestingPayload)
|
||||
register_schema_models(service_api_ns, HitTestingPayload)
|
||||
register_response_schema_models(service_api_ns, HitTestingResponse)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve")
|
||||
@ -13,16 +16,16 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||
@service_api_ns.doc("dataset_hit_testing")
|
||||
@service_api_ns.doc(description="Perform hit testing on a dataset")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Hit testing results",
|
||||
401: "Unauthorized - invalid API token",
|
||||
404: "Dataset not found",
|
||||
}
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Hit testing results",
|
||||
model=service_api_ns.models[HitTestingResponse.__name__],
|
||||
)
|
||||
@service_api_ns.response(401, "Unauthorized - invalid API token")
|
||||
@service_api_ns.response(404, "Dataset not found")
|
||||
@service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__])
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id: UUID):
|
||||
def post(self, tenant_id: str, dataset_id: UUID) -> dict[str, object]:
|
||||
"""Perform hit testing on a dataset.
|
||||
|
||||
Tests retrieval performance for the specified dataset.
|
||||
@ -33,4 +36,4 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||
args = self.parse_args(service_api_ns.payload)
|
||||
self.hit_testing_args_check(args)
|
||||
|
||||
return self.perform_hit_testing(dataset, args)
|
||||
return dump_response(HitTestingResponse, self.perform_hit_testing(dataset, args))
|
||||
|
||||
@ -13,6 +13,7 @@ from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
@ -140,20 +141,26 @@ def cloud_edition_billing_resource_check[**P, R](
|
||||
def interceptor(view: Callable[P, R]):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
features = FeatureService.get_features(api_token.tenant_id)
|
||||
if resource == "vector_space":
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
vector_space = FeatureService.get_vector_space(api_token.tenant_id)
|
||||
if 0 < vector_space.limit <= vector_space.size:
|
||||
raise Forbidden("The capacity of the vector space has reached the limit of your subscription.")
|
||||
return view(*args, **kwargs)
|
||||
|
||||
features = FeatureService.get_features(api_token.tenant_id, exclude_vector_space=True)
|
||||
|
||||
if features.billing.enabled:
|
||||
members = features.members
|
||||
apps = features.apps
|
||||
vector_space = features.vector_space
|
||||
documents_upload_quota = features.documents_upload_quota
|
||||
|
||||
if resource == "members" and 0 < members.limit <= members.size:
|
||||
raise Forbidden("The number of members has reached the limit of your subscription.")
|
||||
elif resource == "apps" and 0 < apps.limit <= apps.size:
|
||||
raise Forbidden("The number of apps has reached the limit of your subscription.")
|
||||
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
|
||||
raise Forbidden("The capacity of the vector space has reached the limit of your subscription.")
|
||||
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
||||
raise Forbidden("The number of documents has reached the limit of your subscription.")
|
||||
else:
|
||||
@ -174,7 +181,7 @@ def cloud_edition_billing_knowledge_limit_check[**P, R](
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
features = FeatureService.get_features(api_token.tenant_id)
|
||||
features = FeatureService.get_features(api_token.tenant_id, exclude_vector_space=True)
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == CloudPlan.SANDBOX:
|
||||
|
||||
@ -12,7 +12,7 @@ from controllers.common.schema import register_response_schema_models, register_
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_webapp_passport
|
||||
from models.model import App, AppMode
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
@ -56,7 +56,7 @@ class AppParameterApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self, app_model: App, end_user):
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
"""Retrieve app parameters."""
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
@ -92,7 +92,7 @@ class AppMeta(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self, app_model: App, end_user):
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
"""Get app meta"""
|
||||
return AppService().get_app_meta(app_model)
|
||||
|
||||
|
||||
@ -29,7 +29,7 @@ from core.errors.error import (
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from models.model import AppMode
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
@ -86,7 +86,7 @@ class CompletionApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model, end_user):
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -140,7 +140,7 @@ class CompletionStopApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(200, "Success", web_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, app_model, end_user, task_id: str):
|
||||
def post(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -169,7 +169,7 @@ class ChatApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model, end_user):
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -226,7 +226,7 @@ class ChatStopApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(200, "Success", web_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, app_model, end_user, task_id: str):
|
||||
def post(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
@ -19,7 +19,7 @@ from fields.conversation_fields import (
|
||||
SimpleConversation,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from models.model import AppMode
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
|
||||
from services.web_conversation_service import WebConversationService
|
||||
@ -81,7 +81,7 @@ class ConversationListApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self, app_model, end_user):
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -127,7 +127,7 @@ class ConversationApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def delete(self, app_model, end_user, c_id: UUID):
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -166,7 +166,7 @@ class ConversationRenameApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model, end_user, c_id: UUID):
|
||||
def post(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -204,7 +204,7 @@ class ConversationPinApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(200, "Conversation pinned successfully", web_ns.models[ResultResponse.__name__])
|
||||
def patch(self, app_model, end_user, c_id: UUID):
|
||||
def patch(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -235,7 +235,7 @@ class ConversationUnPinApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(200, "Conversation unpinned successfully", web_ns.models[ResultResponse.__name__])
|
||||
def patch(self, app_model, end_user, c_id: UUID):
|
||||
def patch(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
@ -13,6 +13,7 @@ from controllers.web import web_ns
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileResponse
|
||||
from models.model import App, EndUser
|
||||
from services.file_service import FileService
|
||||
|
||||
register_schema_models(web_ns, FileResponse)
|
||||
@ -31,7 +32,7 @@ class FileApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(201, "File uploaded successfully", web_ns.models[FileResponse.__name__])
|
||||
def post(self, app_model, end_user):
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
"""Upload a file for use in web applications.
|
||||
|
||||
Accepts file uploads for use within web applications, supporting
|
||||
|
||||
@ -27,7 +27,7 @@ from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfinite
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import AppMode
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
@ -81,7 +81,7 @@ class MessageListApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self, app_model, end_user):
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -133,7 +133,7 @@ class MessageFeedbackApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(200, "Feedback submitted successfully", web_ns.models[ResultResponse.__name__])
|
||||
def post(self, app_model, end_user, message_id: UUID):
|
||||
def post(self, app_model: App, end_user: EndUser, message_id: UUID):
|
||||
message_id_str = str(message_id)
|
||||
|
||||
payload = MessageFeedbackPayload.model_validate(web_ns.payload or {})
|
||||
@ -167,7 +167,7 @@ class MessageMoreLikeThisApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self, app_model, end_user, message_id: UUID):
|
||||
def get(self, app_model: App, end_user: EndUser, message_id: UUID):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -223,7 +223,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self, app_model, end_user, message_id: UUID):
|
||||
def get(self, app_model: App, end_user: EndUser, message_id: UUID):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
@ -13,6 +13,7 @@ from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
|
||||
from graphon.file import helpers as file_helpers
|
||||
from models.model import App, EndUser
|
||||
from services.file_service import FileService
|
||||
|
||||
from ..common.schema import register_response_schema_models, register_schema_models
|
||||
@ -41,7 +42,7 @@ class RemoteFileInfoApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(200, "Remote file info", web_ns.models[RemoteFileInfo.__name__])
|
||||
def get(self, app_model, end_user, url):
|
||||
def get(self, app_model: App, end_user: EndUser, url: str):
|
||||
"""Get information about a remote file.
|
||||
|
||||
Retrieves basic information about a file located at a remote URL,
|
||||
@ -85,7 +86,7 @@ class RemoteFileUploadApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(201, "Remote file uploaded", web_ns.models[FileWithSignedUrl.__name__])
|
||||
def post(self, app_model, end_user):
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
"""Upload a file from a remote URL.
|
||||
|
||||
Downloads a file from the provided remote URL and uploads it
|
||||
|
||||
@ -11,6 +11,7 @@ from controllers.web.error import NotCompletionAppError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from models.model import App, EndUser
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
@ -43,7 +44,7 @@ class SavedMessageListApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self, app_model, end_user):
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -77,7 +78,7 @@ class SavedMessageListApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(200, "Message saved successfully", web_ns.models[ResultResponse.__name__])
|
||||
def post(self, app_model, end_user):
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -106,7 +107,7 @@ class SavedMessageApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def delete(self, app_model, end_user, message_id: UUID):
|
||||
def delete(self, app_model: App, end_user: EndUser, message_id: UUID):
|
||||
message_id_str = str(message_id)
|
||||
|
||||
if app_model.mode != "completion":
|
||||
|
||||
@ -10,7 +10,7 @@ from controllers.web.wraps import WebApiResource
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import AppIconUrlField
|
||||
from models.account import TenantStatus
|
||||
from models.model import App, Site
|
||||
from models.model import App, EndUser, Site
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
@ -70,7 +70,7 @@ class AppSiteApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@marshal_with(app_fields)
|
||||
def get(self, app_model, end_user):
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
"""Retrieve app site info."""
|
||||
# get site
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
@ -78,10 +78,10 @@ class AppSiteApi(WebApiResource):
|
||||
if not site:
|
||||
raise Forbidden()
|
||||
|
||||
if app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||
if app_model.tenant and app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden()
|
||||
|
||||
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
|
||||
can_replace_logo = FeatureService.get_features(app_model.tenant_id, exclude_vector_space=True).can_replace_logo
|
||||
|
||||
return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo)
|
||||
|
||||
@ -119,6 +119,6 @@ def serialize_site(site: Site) -> dict[str, Any]:
|
||||
|
||||
|
||||
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict[str, Any]:
|
||||
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
|
||||
can_replace_logo = FeatureService.get_features(app_model.tenant_id, exclude_vector_space=True).can_replace_logo
|
||||
app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
|
||||
return cast(dict[str, Any], marshal(app_site_info, AppSiteApi.app_fields))
|
||||
|
||||
@ -27,6 +27,7 @@ from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
|
||||
from core.workflow.node_factory import get_default_root_node_id
|
||||
from core.workflow.nodes.agent_v2.session_cleanup_layer import build_workflow_agent_session_cleanup_layer
|
||||
from core.workflow.system_variables import (
|
||||
build_bootstrap_variables,
|
||||
build_system_variables,
|
||||
@ -239,6 +240,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
workflow_entry.graph_engine.layer(build_workflow_agent_session_cleanup_layer())
|
||||
conversation_variable_layer = ConversationVariablePersistenceLayer(
|
||||
ConversationVariableUpdater(session_factory.get_session_maker())
|
||||
)
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AppGenerateEntity,
|
||||
EasyUIBasedAppGenerateEntity,
|
||||
@ -292,46 +293,51 @@ class AppRunner:
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
text = ""
|
||||
usage = None
|
||||
for result in invoke_result:
|
||||
if not agent:
|
||||
queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
|
||||
else:
|
||||
queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
|
||||
try:
|
||||
for result in invoke_result:
|
||||
if not agent:
|
||||
queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
|
||||
else:
|
||||
queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
message = result.delta.message
|
||||
if isinstance(message.content, str):
|
||||
text += message.content
|
||||
elif isinstance(message.content, list):
|
||||
for content in message.content:
|
||||
if isinstance(content, str):
|
||||
text += content
|
||||
elif isinstance(content, TextPromptMessageContent):
|
||||
text += content.data
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
if message_id and user_id and tenant_id:
|
||||
try:
|
||||
self._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
except Exception:
|
||||
_logger.exception("Failed to handle multimodal image output")
|
||||
message = result.delta.message
|
||||
if isinstance(message.content, str):
|
||||
text += message.content
|
||||
elif isinstance(message.content, list):
|
||||
for content in message.content:
|
||||
if isinstance(content, str):
|
||||
text += content
|
||||
elif isinstance(content, TextPromptMessageContent):
|
||||
text += content.data
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
if message_id and user_id and tenant_id:
|
||||
try:
|
||||
self._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
except Exception:
|
||||
_logger.exception("Failed to handle multimodal image output")
|
||||
else:
|
||||
_logger.warning("Received multimodal output but missing required parameters")
|
||||
else:
|
||||
_logger.warning("Received multimodal output but missing required parameters")
|
||||
else:
|
||||
text += content.data if hasattr(content, "data") else str(content)
|
||||
text += content.data if hasattr(content, "data") else str(content)
|
||||
|
||||
if not model:
|
||||
model = result.model
|
||||
if not model:
|
||||
model = result.model
|
||||
|
||||
if not prompt_messages:
|
||||
prompt_messages = list(result.prompt_messages)
|
||||
if not prompt_messages:
|
||||
prompt_messages = list(result.prompt_messages)
|
||||
|
||||
if result.delta.usage:
|
||||
usage = result.delta.usage
|
||||
if result.delta.usage:
|
||||
usage = result.delta.usage
|
||||
except GenerateTaskStoppedError:
|
||||
# Explicitly close provider stream to stop in-flight token generation ASAP.
|
||||
invoke_result.close()
|
||||
raise
|
||||
|
||||
if usage is None:
|
||||
usage = LLMUsage.empty_usage()
|
||||
|
||||
@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
|
||||
from core.workflow.node_factory import get_default_root_node_id
|
||||
from core.workflow.nodes.agent_v2.session_cleanup_layer import build_workflow_agent_session_cleanup_layer
|
||||
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
|
||||
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
@ -166,6 +167,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
workflow_entry.graph_engine.layer(build_workflow_agent_session_cleanup_layer())
|
||||
for layer in self._graph_engine_layers:
|
||||
workflow_entry.graph_engine.layer(layer)
|
||||
|
||||
|
||||
@ -47,6 +47,12 @@ from graphon.graph_events import (
|
||||
)
|
||||
from graphon.node_events import NodeRunResult
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from services.workflow.inspector_events import (
|
||||
publish_node_changed as _inspector_publish_node_changed,
|
||||
)
|
||||
from services.workflow.inspector_events import (
|
||||
publish_workflow_completed as _inspector_publish_workflow_completed,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -163,6 +169,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
|
||||
self._workflow_execution_repository.save(execution)
|
||||
self._enqueue_trace_task(execution)
|
||||
_inspector_publish_workflow_completed(workflow_run_id=execution.id_, status=str(execution.status.value))
|
||||
|
||||
def _handle_graph_run_partial_succeeded(self, event: GraphRunPartialSucceededEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
@ -173,6 +180,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
|
||||
self._workflow_execution_repository.save(execution)
|
||||
self._enqueue_trace_task(execution)
|
||||
_inspector_publish_workflow_completed(workflow_run_id=execution.id_, status=str(execution.status.value))
|
||||
|
||||
def _handle_graph_run_failed(self, event: GraphRunFailedEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
@ -184,6 +192,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
self._fail_running_node_executions(error_message=event.error)
|
||||
self._workflow_execution_repository.save(execution)
|
||||
self._enqueue_trace_task(execution)
|
||||
_inspector_publish_workflow_completed(workflow_run_id=execution.id_, status=str(execution.status.value))
|
||||
|
||||
def _handle_graph_run_aborted(self, event: GraphRunAbortedEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
@ -194,6 +203,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
self._fail_running_node_executions(error_message=execution.error_message or "")
|
||||
self._workflow_execution_repository.save(execution)
|
||||
self._enqueue_trace_task(execution)
|
||||
_inspector_publish_workflow_completed(workflow_run_id=execution.id_, status=str(execution.status.value))
|
||||
|
||||
def _handle_graph_run_paused(self, event: GraphRunPausedEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
@ -241,6 +251,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
created_at=event.start_at,
|
||||
)
|
||||
self._node_snapshots[event.id] = snapshot
|
||||
_inspector_publish_node_changed(workflow_run_id=execution.id_, node_id=event.node_id, status="running")
|
||||
|
||||
def _handle_node_retry(self, event: NodeRunRetryEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
@ -248,6 +259,11 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
domain_execution.error = event.error
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
self._workflow_node_execution_repository.save_execution_data(domain_execution)
|
||||
_inspector_publish_node_changed(
|
||||
workflow_run_id=self._get_workflow_execution().id_,
|
||||
node_id=domain_execution.node_id,
|
||||
status="retry",
|
||||
)
|
||||
|
||||
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
@ -257,6 +273,11 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
finished_at=event.finished_at,
|
||||
)
|
||||
_inspector_publish_node_changed(
|
||||
workflow_run_id=self._get_workflow_execution().id_,
|
||||
node_id=domain_execution.node_id,
|
||||
status="succeeded",
|
||||
)
|
||||
|
||||
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
@ -267,6 +288,11 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
error=event.error,
|
||||
finished_at=event.finished_at,
|
||||
)
|
||||
_inspector_publish_node_changed(
|
||||
workflow_run_id=self._get_workflow_execution().id_,
|
||||
node_id=domain_execution.node_id,
|
||||
status="failed",
|
||||
)
|
||||
|
||||
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
@ -277,6 +303,11 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
error=event.error,
|
||||
finished_at=event.finished_at,
|
||||
)
|
||||
_inspector_publish_node_changed(
|
||||
workflow_run_id=self._get_workflow_execution().id_,
|
||||
node_id=domain_execution.node_id,
|
||||
status="exception",
|
||||
)
|
||||
|
||||
def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
|
||||
@ -534,7 +534,9 @@ class ProviderManager:
|
||||
cache_key = f"tenant:{tenant_id}:model_load_balancing_enabled"
|
||||
cache_result = redis_client.get(cache_key)
|
||||
if cache_result is None:
|
||||
model_load_balancing_enabled = FeatureService.get_features(tenant_id).model_load_balancing_enabled
|
||||
model_load_balancing_enabled = FeatureService.get_features(
|
||||
tenant_id, exclude_vector_space=True
|
||||
).model_load_balancing_enabled
|
||||
redis_client.setex(cache_key, 120, str(model_load_balancing_enabled))
|
||||
else:
|
||||
cache_result = cache_result.decode("utf-8")
|
||||
|
||||
@ -863,7 +863,7 @@ class ToolManager:
|
||||
return controller
|
||||
|
||||
@classmethod
|
||||
def user_get_api_provider(cls, provider: str, tenant_id: str):
|
||||
def user_get_api_provider(cls, provider: str, tenant_id: str, mask: bool = True):
|
||||
"""
|
||||
get api provider
|
||||
"""
|
||||
@ -902,8 +902,10 @@ class ToolManager:
|
||||
tenant_id=tenant_id,
|
||||
controller=controller,
|
||||
)
|
||||
|
||||
masked_credentials = encrypter.mask_plugin_credentials(encrypter.decrypt(credentials))
|
||||
if mask:
|
||||
masked_credentials = encrypter.mask_plugin_credentials(encrypter.decrypt(credentials))
|
||||
else:
|
||||
masked_credentials = encrypter.decrypt(credentials)
|
||||
|
||||
try:
|
||||
icon = emoji_icon_adapter.validate_json(provider_obj.icon)
|
||||
|
||||
@ -6,7 +6,7 @@ from json.decoder import JSONDecodeError
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import httpx
|
||||
from flask import request
|
||||
from flask import has_request_context, request
|
||||
from yaml import YAMLError, safe_load
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
@ -44,7 +44,7 @@ class ApiBasedToolSchemaParser:
|
||||
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
|
||||
|
||||
server_url = openapi["servers"][0]["url"]
|
||||
request_env = request.headers.get("X-Request-Env")
|
||||
request_env = request.headers.get("X-Request-Env") if has_request_context() else None
|
||||
if request_env:
|
||||
matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env]
|
||||
server_url = matched_servers[0] if matched_servers else server_url
|
||||
|
||||
@ -475,6 +475,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
from core.workflow.nodes.agent_v2.file_tenant_validator import UploadFileTenantValidator
|
||||
from core.workflow.nodes.agent_v2.output_failure_orchestrator import OutputFailureOrchestrator
|
||||
from core.workflow.nodes.agent_v2.output_type_checker import PerOutputTypeChecker
|
||||
from core.workflow.nodes.agent_v2.session_store import WorkflowAgentRuntimeSessionStore
|
||||
|
||||
return {
|
||||
"binding_resolver": WorkflowAgentBindingResolver(),
|
||||
@ -494,6 +495,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
# outputs contain no file refs.
|
||||
"type_checker": PerOutputTypeChecker(file_validator=UploadFileTenantValidator()),
|
||||
"failure_orchestrator": OutputFailureOrchestrator(),
|
||||
"session_store": WorkflowAgentRuntimeSessionStore(),
|
||||
}
|
||||
return {
|
||||
"strategy_resolver": self._agent_strategy_resolver,
|
||||
|
||||
@ -1,8 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
|
||||
from clients.agent_backend import (
|
||||
AgentBackendError,
|
||||
AgentBackendHTTPError,
|
||||
@ -17,11 +20,14 @@ from clients.agent_backend import (
|
||||
AgentBackendStreamInternalEvent,
|
||||
AgentBackendTransportError,
|
||||
AgentBackendValidationError,
|
||||
CleanupLayerSpec,
|
||||
extract_cleanup_layer_specs,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
from graphon.entities.pause_reason import SchedulingPause
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
|
||||
from graphon.node_events import NodeEventBase, NodeRunResult, PauseRequestedEvent, StreamCompletedEvent
|
||||
from graphon.nodes.base.node import Node
|
||||
from models.agent_config_entities import WorkflowNodeJobConfig
|
||||
|
||||
@ -40,11 +46,14 @@ from .runtime_request_builder import (
|
||||
WorkflowAgentRuntimeRequestBuilder,
|
||||
WorkflowAgentRuntimeRequestBuildError,
|
||||
)
|
||||
from .session_store import WorkflowAgentRuntimeSessionStore, WorkflowAgentSessionScope
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphon.entities import GraphInitParams
|
||||
from graphon.runtime import GraphRuntimeState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Stage 4 §5+§7: the terminal events that `_consume_event_stream` may return.
|
||||
# Stream + started events are filtered out before we yield; transport errors
|
||||
@ -74,6 +83,7 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
output_adapter: WorkflowAgentOutputAdapter,
|
||||
type_checker: PerOutputTypeChecker,
|
||||
failure_orchestrator: OutputFailureOrchestrator,
|
||||
session_store: WorkflowAgentRuntimeSessionStore | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
node_id=node_id,
|
||||
@ -88,6 +98,7 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
self._output_adapter = output_adapter
|
||||
self._type_checker = type_checker
|
||||
self._failure_orchestrator = failure_orchestrator
|
||||
self._session_store = session_store
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
@ -134,6 +145,17 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
"agent_config_snapshot_id": bundle.snapshot.id,
|
||||
"binding_id": bundle.binding.id,
|
||||
}
|
||||
session_scope = WorkflowAgentSessionScope(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
node_id=self._node_id,
|
||||
node_execution_id=self.id,
|
||||
binding_id=bundle.binding.id,
|
||||
agent_id=bundle.agent.id,
|
||||
agent_config_snapshot_id=bundle.snapshot.id,
|
||||
)
|
||||
|
||||
# Stage 4 §4.1 (D-3): use effective outputs so defaults flow through both
|
||||
# the backend request and the post-run type check.
|
||||
@ -147,6 +169,9 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
attempt = 0
|
||||
while True:
|
||||
try:
|
||||
session_snapshot = None
|
||||
if self._session_store is not None:
|
||||
session_snapshot = self._session_store.load_active_snapshot(session_scope)
|
||||
runtime_request = self._runtime_request_builder.build(
|
||||
WorkflowAgentRuntimeBuildContext(
|
||||
dify_context=dify_ctx,
|
||||
@ -159,6 +184,7 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
agent=bundle.agent,
|
||||
snapshot=bundle.snapshot,
|
||||
attempt=attempt,
|
||||
session_snapshot=session_snapshot,
|
||||
)
|
||||
)
|
||||
except WorkflowAgentRuntimeRequestBuildError as error:
|
||||
@ -221,9 +247,35 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
)
|
||||
return
|
||||
|
||||
# Non-success terminal (failed / cancelled / paused) skips per-output
|
||||
# post-processing — the backend itself already failed.
|
||||
if isinstance(terminal_event, AgentBackendRunPausedInternalEvent):
|
||||
self._save_session_snapshot(
|
||||
session_scope=session_scope,
|
||||
backend_run_id=terminal_event.run_id,
|
||||
snapshot=terminal_event.session_snapshot,
|
||||
composition_layer_specs=extract_cleanup_layer_specs(runtime_request.request.composition),
|
||||
metadata=metadata,
|
||||
)
|
||||
yield PauseRequestedEvent(
|
||||
reason=SchedulingPause(
|
||||
message=terminal_event.message
|
||||
or "Agent backend run requested workflow pause for external input."
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Non-success terminal (failed / cancelled) skips per-output
|
||||
# post-processing — the backend itself already failed. We also retire
|
||||
# the local ACTIVE session row so a workflow loop back into the same
|
||||
# Agent node cannot resume from a stale snapshot. The failed agent
|
||||
# backend layers (suspended per ``on_exit``) are left for agent
|
||||
# backend's own GC; this row will no longer be picked up by the
|
||||
# workflow-terminal cleanup layer.
|
||||
if not isinstance(terminal_event, AgentBackendRunSucceededInternalEvent):
|
||||
self._mark_session_cleaned_on_failure(
|
||||
session_scope=session_scope,
|
||||
backend_run_id=terminal_event.run_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=self._output_adapter.build_failure_result(
|
||||
event=terminal_event,
|
||||
@ -234,6 +286,14 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
)
|
||||
return
|
||||
|
||||
self._save_session_snapshot(
|
||||
session_scope=session_scope,
|
||||
backend_run_id=terminal_event.run_id,
|
||||
snapshot=terminal_event.session_snapshot,
|
||||
composition_layer_specs=extract_cleanup_layer_specs(runtime_request.request.composition),
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
# ──── Stage 4: per-output type check ────
|
||||
type_check = self._type_checker.check(
|
||||
declared_outputs=effective_outputs,
|
||||
@ -384,6 +444,75 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
],
|
||||
}
|
||||
|
||||
def _save_session_snapshot(
|
||||
self,
|
||||
*,
|
||||
session_scope: WorkflowAgentSessionScope,
|
||||
backend_run_id: str,
|
||||
snapshot: CompositorSessionSnapshot | None,
|
||||
composition_layer_specs: list[CleanupLayerSpec],
|
||||
metadata: dict[str, Any],
|
||||
) -> None:
|
||||
if self._session_store is None:
|
||||
return
|
||||
try:
|
||||
self._session_store.save_active_snapshot(
|
||||
scope=session_scope,
|
||||
backend_run_id=backend_run_id,
|
||||
snapshot=snapshot,
|
||||
composition_layer_specs=composition_layer_specs,
|
||||
)
|
||||
agent_backend = dict(metadata.get("agent_backend") or {})
|
||||
agent_backend["session_snapshot_persisted"] = snapshot is not None
|
||||
metadata["agent_backend"] = agent_backend
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist workflow Agent runtime session snapshot: "
|
||||
"tenant_id=%s workflow_run_id=%s node_id=%s binding_id=%s agent_id=%s backend_run_id=%s",
|
||||
session_scope.tenant_id,
|
||||
session_scope.workflow_run_id,
|
||||
session_scope.node_id,
|
||||
session_scope.binding_id,
|
||||
session_scope.agent_id,
|
||||
backend_run_id,
|
||||
exc_info=True,
|
||||
)
|
||||
agent_backend = dict(metadata.get("agent_backend") or {})
|
||||
agent_backend["session_snapshot_persisted"] = False
|
||||
agent_backend["session_snapshot_persist_error"] = "workflow_agent_runtime_session_store_error"
|
||||
metadata["agent_backend"] = agent_backend
|
||||
|
||||
def _mark_session_cleaned_on_failure(
|
||||
self,
|
||||
*,
|
||||
session_scope: WorkflowAgentSessionScope,
|
||||
backend_run_id: str,
|
||||
metadata: dict[str, Any],
|
||||
) -> None:
|
||||
if self._session_store is None:
|
||||
return
|
||||
try:
|
||||
self._session_store.mark_cleaned(scope=session_scope, backend_run_id=backend_run_id)
|
||||
agent_backend = dict(metadata.get("agent_backend") or {})
|
||||
agent_backend["session_snapshot_cleaned_on_failure"] = True
|
||||
metadata["agent_backend"] = agent_backend
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to mark workflow Agent runtime session cleaned on agent run failure: "
|
||||
"tenant_id=%s workflow_run_id=%s node_id=%s binding_id=%s agent_id=%s backend_run_id=%s",
|
||||
session_scope.tenant_id,
|
||||
session_scope.workflow_run_id,
|
||||
session_scope.node_id,
|
||||
session_scope.binding_id,
|
||||
session_scope.agent_id,
|
||||
backend_run_id,
|
||||
exc_info=True,
|
||||
)
|
||||
agent_backend = dict(metadata.get("agent_backend") or {})
|
||||
agent_backend["session_snapshot_cleaned_on_failure"] = False
|
||||
agent_backend["session_snapshot_cleanup_error"] = "workflow_agent_runtime_session_store_error"
|
||||
metadata["agent_backend"] = agent_backend
|
||||
|
||||
@staticmethod
|
||||
def _patch_event_with_defaults(
|
||||
event: AgentBackendRunSucceededInternalEvent,
|
||||
|
||||
268
api/core/workflow/nodes/agent_v2/plugin_tools_builder.py
Normal file
268
api/core/workflow/nodes/agent_v2/plugin_tools_builder.py
Normal file
@ -0,0 +1,268 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
from dify_agent.layers.dify_plugin import (
|
||||
DifyPluginCredentialValue,
|
||||
DifyPluginToolConfig,
|
||||
DifyPluginToolCredentialType,
|
||||
DifyPluginToolParameter,
|
||||
DifyPluginToolParameterForm,
|
||||
DifyPluginToolsLayerConfig,
|
||||
)
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.errors import (
|
||||
ToolProviderCredentialValidationError,
|
||||
ToolProviderNotFoundError,
|
||||
)
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from models.agent_config_entities import AgentSoulDifyToolConfig, AgentSoulToolsConfig
|
||||
from models.provider_ids import ToolProviderID
|
||||
|
||||
|
||||
class WorkflowAgentPluginToolsBuildError(ValueError):
|
||||
"""Raised when Agent Soul tools cannot be prepared for Agent backend."""
|
||||
|
||||
def __init__(self, error_code: str, message: str) -> None:
|
||||
self.error_code = error_code
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentToolRuntimeProvider(Protocol):
|
||||
def get_agent_tool_runtime(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
agent_tool: AgentToolEntity,
|
||||
user_id: str | None = None,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
variable_pool: Any | None = None,
|
||||
) -> Tool: ...
|
||||
|
||||
|
||||
class WorkflowAgentPluginToolsBuilder:
|
||||
"""Prepare Agent Soul Dify Plugin Tools for the public Agent backend DTO."""
|
||||
|
||||
def __init__(self, *, tool_runtime_provider: AgentToolRuntimeProvider | None = None) -> None:
|
||||
self._tool_runtime_provider = tool_runtime_provider or ToolManager
|
||||
|
||||
def build(
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
user_id: str | None,
|
||||
tools: AgentSoulToolsConfig,
|
||||
invoke_from: InvokeFrom,
|
||||
) -> DifyPluginToolsLayerConfig | None:
|
||||
"""Resolve user-selected Dify Plugin Tools into the Agent backend DTO.
|
||||
|
||||
``invoke_from`` is the *real* runtime caller category (DEBUGGER for a
|
||||
Composer test run, SERVICE_API / WEB_APP for a published run). It must
|
||||
be threaded through to :class:`ToolManager` so credential quotas, rate
|
||||
limits, and audit tags match the actual call site.
|
||||
"""
|
||||
enabled_tools = [tool for tool in tools.dify_tools if tool.enabled]
|
||||
if not enabled_tools:
|
||||
return None
|
||||
|
||||
prepared: list[DifyPluginToolConfig] = []
|
||||
seen_names: set[str] = set()
|
||||
for tool_config in enabled_tools:
|
||||
agent_tool = self._to_agent_tool_entity(tool_config)
|
||||
tool_runtime = self._fetch_tool_runtime(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
agent_tool=agent_tool,
|
||||
invoke_from=invoke_from,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
exposed_name = self._exposed_tool_name(tool_config)
|
||||
if exposed_name in seen_names:
|
||||
raise WorkflowAgentPluginToolsBuildError(
|
||||
"agent_tool_name_duplicated",
|
||||
f"Duplicate Dify Plugin Tool name {exposed_name!r}.",
|
||||
)
|
||||
seen_names.add(exposed_name)
|
||||
|
||||
prepared.append(self._to_backend_tool_config(tool_config, tool_runtime, exposed_name))
|
||||
|
||||
return DifyPluginToolsLayerConfig(tools=prepared)
|
||||
|
||||
def _fetch_tool_runtime(
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
user_id: str | None,
|
||||
agent_tool: AgentToolEntity,
|
||||
invoke_from: InvokeFrom,
|
||||
tool_config: AgentSoulDifyToolConfig,
|
||||
) -> Tool:
|
||||
"""Resolve the API-side ``Tool`` runtime, mapping fetch errors to
|
||||
Inspector-friendly error codes so callers can render distinct UX for
|
||||
"tool definition gone" vs "credential failed".
|
||||
"""
|
||||
try:
|
||||
return self._tool_runtime_provider.get_agent_tool_runtime(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
agent_tool=agent_tool,
|
||||
user_id=user_id,
|
||||
invoke_from=invoke_from,
|
||||
variable_pool=None,
|
||||
)
|
||||
except ToolProviderNotFoundError as exc:
|
||||
raise WorkflowAgentPluginToolsBuildError(
|
||||
"agent_tool_declaration_not_found",
|
||||
f"Dify Plugin Tool {tool_config.tool_name!r} declaration not found: {exc}",
|
||||
) from exc
|
||||
except ToolProviderCredentialValidationError as exc:
|
||||
raise WorkflowAgentPluginToolsBuildError(
|
||||
"agent_tool_credential_invalid",
|
||||
f"Dify Plugin Tool {tool_config.tool_name!r} credential validation failed: {exc}",
|
||||
) from exc
|
||||
except ValueError as exc:
|
||||
# ToolManager raises bare ValueError when the agent tool's
|
||||
# ``runtime`` / runtime parameters are missing. Surface it under a
|
||||
# narrower error code than a generic "declaration not found" so
|
||||
# frontend can render an actionable hint.
|
||||
raise WorkflowAgentPluginToolsBuildError(
|
||||
"agent_tool_config_invalid",
|
||||
f"Dify Plugin Tool {tool_config.tool_name!r} runtime construction failed: {exc}",
|
||||
) from exc
|
||||
|
||||
@staticmethod
|
||||
def _to_agent_tool_entity(tool_config: AgentSoulDifyToolConfig) -> AgentToolEntity:
|
||||
return AgentToolEntity(
|
||||
provider_type=ToolProviderType.value_of(tool_config.provider_type),
|
||||
provider_id=WorkflowAgentPluginToolsBuilder._provider_id(tool_config),
|
||||
tool_name=tool_config.tool_name,
|
||||
tool_parameters=dict(tool_config.runtime_parameters),
|
||||
credential_id=tool_config.credential_ref.id if tool_config.credential_ref else None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _provider_id(tool_config: AgentSoulDifyToolConfig) -> str:
|
||||
if tool_config.provider_id:
|
||||
return tool_config.provider_id
|
||||
assert tool_config.plugin_id is not None
|
||||
assert tool_config.provider is not None
|
||||
return f"{tool_config.plugin_id}/{tool_config.provider}"
|
||||
|
||||
@staticmethod
|
||||
def _exposed_tool_name(tool_config: AgentSoulDifyToolConfig) -> str:
|
||||
# Stage 3.1 decision: no user rename yet. Keep the model-visible tool
|
||||
# name aligned with the plugin declaration identity.
|
||||
return tool_config.tool_name
|
||||
|
||||
def _to_backend_tool_config(
|
||||
self,
|
||||
tool_config: AgentSoulDifyToolConfig,
|
||||
tool_runtime: Tool,
|
||||
exposed_name: str,
|
||||
) -> DifyPluginToolConfig:
|
||||
runtime = tool_runtime.runtime
|
||||
if runtime is None:
|
||||
raise WorkflowAgentPluginToolsBuildError(
|
||||
"agent_tool_config_invalid",
|
||||
f"Dify Plugin Tool {tool_config.tool_name!r} has no runtime.",
|
||||
)
|
||||
|
||||
provider_id = self._provider_id(tool_config)
|
||||
plugin_id, provider = self._plugin_provider(tool_config, provider_id)
|
||||
parameters = [
|
||||
DifyPluginToolParameter.model_validate(parameter.model_dump(mode="json"))
|
||||
for parameter in tool_runtime.get_merged_runtime_parameters()
|
||||
]
|
||||
runtime_parameters = self._runtime_parameters(tool_runtime, parameters)
|
||||
description = tool_config.description
|
||||
if description is None and tool_runtime.entity.description is not None:
|
||||
description = tool_runtime.entity.description.llm
|
||||
|
||||
return DifyPluginToolConfig(
|
||||
plugin_id=plugin_id,
|
||||
provider=provider,
|
||||
tool_name=tool_config.tool_name,
|
||||
credential_type=self._credential_type(tool_config, runtime.credentials),
|
||||
name=exposed_name,
|
||||
description=description,
|
||||
credentials=self._normalize_credentials(runtime.credentials, tool_name=tool_config.tool_name),
|
||||
runtime_parameters=runtime_parameters,
|
||||
parameters=parameters,
|
||||
parameters_json_schema=cast(dict[str, Any], tool_runtime.get_llm_parameters_json_schema()),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _plugin_provider(tool_config: AgentSoulDifyToolConfig, provider_id: str) -> tuple[str, str]:
|
||||
if tool_config.plugin_id and tool_config.provider:
|
||||
return tool_config.plugin_id, tool_config.provider
|
||||
provider_id_entity = ToolProviderID(provider_id)
|
||||
return provider_id_entity.plugin_id, provider_id_entity.provider_name
|
||||
|
||||
@staticmethod
|
||||
def _credential_type(
|
||||
tool_config: AgentSoulDifyToolConfig,
|
||||
credentials: Mapping[str, Any],
|
||||
) -> DifyPluginToolCredentialType:
|
||||
if not credentials and tool_config.credential_type == "unauthorized":
|
||||
return "unauthorized"
|
||||
return tool_config.credential_type
|
||||
|
||||
@staticmethod
|
||||
def _runtime_parameters(
|
||||
tool_runtime: Tool,
|
||||
parameters: list[DifyPluginToolParameter],
|
||||
) -> dict[str, Any]:
|
||||
runtime = tool_runtime.runtime
|
||||
runtime_parameters = dict(runtime.runtime_parameters if runtime is not None else {})
|
||||
missing = [
|
||||
parameter.name
|
||||
for parameter in parameters
|
||||
if parameter.form is not DifyPluginToolParameterForm.LLM
|
||||
and parameter.required
|
||||
and parameter.default is None
|
||||
and parameter.name not in runtime_parameters
|
||||
]
|
||||
if missing:
|
||||
names = ", ".join(sorted(missing))
|
||||
raise WorkflowAgentPluginToolsBuildError(
|
||||
"agent_tool_runtime_parameter_missing",
|
||||
f"Dify Plugin Tool {tool_runtime.entity.identity.name!r} is missing runtime parameters: {names}.",
|
||||
)
|
||||
return runtime_parameters
|
||||
|
||||
@staticmethod
|
||||
def _normalize_credentials(
|
||||
credentials: Mapping[str, Any],
|
||||
*,
|
||||
tool_name: str,
|
||||
) -> dict[str, DifyPluginCredentialValue]:
|
||||
"""Forward only scalar credential values to the Agent backend.
|
||||
|
||||
``DifyPluginCredentialValue`` is ``str | int | float | bool | None``.
|
||||
Refusing non-scalar values (lists, dicts, custom objects) is safer than
|
||||
``str(value)`` — stringifying a nested OAuth token blob produces a
|
||||
Python ``repr`` that the plugin daemon cannot use, and we'd rather
|
||||
surface a clear ``agent_tool_credential_shape_invalid`` than send junk.
|
||||
"""
|
||||
normalized: dict[str, DifyPluginCredentialValue] = {}
|
||||
for key, value in credentials.items():
|
||||
if isinstance(value, str | int | float | bool) or value is None:
|
||||
normalized[key] = value
|
||||
continue
|
||||
raise WorkflowAgentPluginToolsBuildError(
|
||||
"agent_tool_credential_shape_invalid",
|
||||
(
|
||||
f"Dify Plugin Tool {tool_name!r} credential {key!r} has a non-scalar value "
|
||||
f"({type(value).__name__}); only str/int/float/bool/None are forwarded to the daemon."
|
||||
),
|
||||
)
|
||||
return normalized
|
||||
@ -11,13 +11,14 @@ SUPPORTED_AGENT_BACKEND_FEATURES = frozenset(
|
||||
"workflow_context",
|
||||
"model",
|
||||
"structured_output",
|
||||
"tools.dify_tools",
|
||||
}
|
||||
)
|
||||
|
||||
RESERVED_AGENT_BACKEND_FEATURES = frozenset(
|
||||
{
|
||||
"skills_files",
|
||||
"tools",
|
||||
"tools.cli_tools",
|
||||
"knowledge",
|
||||
"human",
|
||||
"env",
|
||||
@ -32,7 +33,7 @@ def build_runtime_feature_manifest(agent_soul: AgentSoulConfig) -> dict[str, Any
|
||||
warnings: list[dict[str, str]] = []
|
||||
soul_dump = agent_soul.model_dump(mode="json")
|
||||
for section in sorted(RESERVED_AGENT_BACKEND_FEATURES):
|
||||
value = soul_dump.get(section)
|
||||
value = _get_nested(soul_dump, section)
|
||||
has_value = bool(value)
|
||||
if isinstance(value, dict):
|
||||
has_value = any(bool(item) for item in value.values())
|
||||
@ -41,11 +42,12 @@ def build_runtime_feature_manifest(agent_soul: AgentSoulConfig) -> dict[str, Any
|
||||
{
|
||||
"section": f"agent_soul.{section}",
|
||||
"code": "agent_backend_layer_not_available",
|
||||
"message": f"{section} is saved in Agent Soul but is not executed by Agent backend in phase 3.",
|
||||
"message": f"{section} is saved in Agent Soul but is not executed by Agent backend.",
|
||||
}
|
||||
)
|
||||
|
||||
reserved_status = dict.fromkeys(sorted(RESERVED_AGENT_BACKEND_FEATURES), "reserved_not_executed")
|
||||
reserved_status["tools.dify_tools"] = "supported_when_config_valid"
|
||||
|
||||
return {
|
||||
"supported": sorted(SUPPORTED_AGENT_BACKEND_FEATURES),
|
||||
@ -53,3 +55,12 @@ def build_runtime_feature_manifest(agent_soul: AgentSoulConfig) -> dict[str, Any
|
||||
"reserved_status": reserved_status,
|
||||
"unsupported_runtime_warnings": warnings,
|
||||
}
|
||||
|
||||
|
||||
def _get_nested(value: dict[str, Any], path: str) -> Any:
|
||||
current: Any = value
|
||||
for part in path.split("."):
|
||||
if not isinstance(current, dict):
|
||||
return None
|
||||
current = current.get(part)
|
||||
return current
|
||||
|
||||
@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, Protocol, cast
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig
|
||||
from dify_agent.protocol import CreateRunRequest
|
||||
|
||||
@ -28,8 +29,10 @@ from models.agent_config_entities import (
|
||||
from models.agent_config_entities import (
|
||||
effective_declared_outputs as _effective_declared_outputs,
|
||||
)
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
from .output_failure_orchestrator import retry_idempotency_key
|
||||
from .plugin_tools_builder import WorkflowAgentPluginToolsBuilder, WorkflowAgentPluginToolsBuildError
|
||||
from .runtime_feature_manifest import build_runtime_feature_manifest
|
||||
|
||||
|
||||
@ -65,6 +68,7 @@ class WorkflowAgentRuntimeBuildContext:
|
||||
# Stage 4 §7 / D-4: 0 for the first run, then incremented per retry. Drives the
|
||||
# idempotency key so the backend treats each retry as a fresh request.
|
||||
attempt: int = 0
|
||||
session_snapshot: CompositorSessionSnapshot | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
@ -84,9 +88,11 @@ class WorkflowAgentRuntimeRequestBuilder:
|
||||
*,
|
||||
credentials_provider: CredentialsProvider,
|
||||
request_builder: AgentBackendRunRequestBuilder | None = None,
|
||||
plugin_tools_builder: WorkflowAgentPluginToolsBuilder | None = None,
|
||||
) -> None:
|
||||
self._credentials_provider = credentials_provider
|
||||
self._request_builder = request_builder or AgentBackendRunRequestBuilder()
|
||||
self._plugin_tools_builder = plugin_tools_builder or WorkflowAgentPluginToolsBuilder()
|
||||
|
||||
def build(self, context: WorkflowAgentRuntimeBuildContext) -> WorkflowAgentRuntimeRequest:
|
||||
agent_soul = AgentSoulConfig.model_validate(context.snapshot.config_snapshot_dict)
|
||||
@ -102,15 +108,38 @@ class WorkflowAgentRuntimeRequestBuilder:
|
||||
workflow_job_prompt = node_job.workflow_prompt.strip() or "Run this workflow Agent Node for the current run."
|
||||
user_prompt = workflow_context_prompt.strip() or "Use the current workflow context."
|
||||
credentials = self._credentials_provider.fetch(agent_soul.model.model_provider, agent_soul.model.model)
|
||||
try:
|
||||
tools_layer = self._plugin_tools_builder.build(
|
||||
tenant_id=context.dify_context.tenant_id,
|
||||
app_id=context.dify_context.app_id,
|
||||
user_id=context.dify_context.user_id,
|
||||
tools=agent_soul.tools,
|
||||
# Thread the *real* runtime invocation source through to
|
||||
# ToolManager so credential quotas, rate limits, and audit
|
||||
# trails match the actual call site (DEBUGGER for draft test
|
||||
# run, SERVICE_API / WEB_APP for published run).
|
||||
invoke_from=context.dify_context.invoke_from,
|
||||
)
|
||||
except WorkflowAgentPluginToolsBuildError as error:
|
||||
raise WorkflowAgentRuntimeRequestBuildError(error.error_code, str(error)) from error
|
||||
if tools_layer is not None:
|
||||
metadata["agent_tools"] = {
|
||||
"dify_tool_count": len(tools_layer.tools),
|
||||
"dify_tool_names": [tool.name or tool.tool_name for tool in tools_layer.tools],
|
||||
"cli_tool_count": len(agent_soul.tools.cli_tools),
|
||||
}
|
||||
|
||||
request = self._request_builder.build_for_workflow_node(
|
||||
AgentBackendWorkflowNodeRunInput(
|
||||
model=AgentBackendModelConfig(
|
||||
plugin_id=agent_soul.model.plugin_id,
|
||||
model_provider=agent_soul.model.model_provider,
|
||||
plugin_id=self._plugin_daemon_plugin_id(
|
||||
plugin_id=agent_soul.model.plugin_id,
|
||||
model_provider=agent_soul.model.model_provider,
|
||||
),
|
||||
model_provider=self._plugin_daemon_provider_name(agent_soul.model.model_provider),
|
||||
model=agent_soul.model.model,
|
||||
credentials=self._normalize_credentials(credentials),
|
||||
model_settings=cast(dict[str, Any], agent_soul.model.model_settings),
|
||||
model_settings=agent_soul.model.model_settings,
|
||||
),
|
||||
# The execution-context layer is now the only public protocol
|
||||
# carrier for Dify tenant/user/run identifiers. ``user_id`` must
|
||||
@ -134,6 +163,8 @@ class WorkflowAgentRuntimeRequestBuilder:
|
||||
workflow_node_job_prompt=workflow_job_prompt,
|
||||
user_prompt=user_prompt,
|
||||
output=self._build_output_config(node_job.declared_outputs),
|
||||
tools=tools_layer,
|
||||
session_snapshot=context.session_snapshot,
|
||||
idempotency_key=self._idempotency_key(context),
|
||||
metadata=metadata,
|
||||
)
|
||||
@ -153,6 +184,20 @@ class WorkflowAgentRuntimeRequestBuilder:
|
||||
return "single_step"
|
||||
return "workflow_run"
|
||||
|
||||
@staticmethod
|
||||
def _plugin_daemon_plugin_id(*, plugin_id: str, model_provider: str) -> str:
|
||||
"""Return the transport plugin id expected by plugin-daemon headers."""
|
||||
if plugin_id.count("/") == 1:
|
||||
return plugin_id
|
||||
if plugin_id:
|
||||
return ModelProviderID(plugin_id).plugin_id
|
||||
return ModelProviderID(model_provider).plugin_id
|
||||
|
||||
@staticmethod
|
||||
def _plugin_daemon_provider_name(model_provider: str) -> str:
|
||||
"""Return the provider name expected by plugin-daemon dispatch payloads."""
|
||||
return ModelProviderID(model_provider).provider_name
|
||||
|
||||
@staticmethod
|
||||
def _idempotency_key(context: WorkflowAgentRuntimeBuildContext) -> str:
|
||||
# Stage 4 §7 / D-4: retries get distinct keys (``...:retry-{attempt}``) so
|
||||
|
||||
247
api/core/workflow/nodes/agent_v2/session_cleanup_layer.py
Normal file
247
api/core/workflow/nodes/agent_v2/session_cleanup_layer.py
Normal file
@ -0,0 +1,247 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import override
|
||||
|
||||
from clients.agent_backend import AgentBackendError, AgentBackendRunClient, AgentBackendRunRequestBuilder
|
||||
from clients.agent_backend.factory import create_agent_backend_run_client
|
||||
from configs import dify_config
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
|
||||
from .session_store import StoredWorkflowAgentSession, WorkflowAgentRuntimeSessionStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Upper bound on how long a cleanup-only run is allowed to settle before the
|
||||
# layer gives up and leaves the row ACTIVE so it can be retried later. Cleanup
|
||||
# work is mostly local agent-backend bookkeeping (no LLM inference), so 30s is
|
||||
# generous; a hung backend should never block workflow termination beyond this.
|
||||
_CLEANUP_WAIT_TIMEOUT_SECONDS = 30.0
|
||||
|
||||
|
||||
class WorkflowAgentSessionCleanupLayer(GraphEngineLayer):
|
||||
"""Retires workflow Agent session snapshots when a workflow reaches a terminal state.
|
||||
|
||||
Implementation notes — there are two failure modes the cleanup path has to
|
||||
avoid simultaneously:
|
||||
|
||||
1. The agenton compositor on the agent-backend side validates the cleanup
|
||||
request's session snapshot against the replayed composition before
|
||||
running any lifecycle hook. If the snapshot's layer names diverge from
|
||||
the composition, the run fails asynchronously with ``run_failed`` — but
|
||||
the initial ``POST /runs`` already returned 202, so the API side has no
|
||||
visibility of the failure unless it waits for terminal status. The
|
||||
``composition_layer_specs`` persistence in A.1–A.4 plus the
|
||||
``_filter_snapshot_to_specs`` shape in ``build_cleanup_request`` keeps
|
||||
the two name lists in sync.
|
||||
|
||||
2. The current agent backend's ``runner.py::_run_agent`` always invokes
|
||||
``run.get_layer("llm")`` and the structured-output / history validators
|
||||
before exiting any slot — there is no ``purpose: "cleanup"`` branch
|
||||
yet. A truly cleanup-only request (no LLM layer) therefore still
|
||||
crashes inside the runner with ``Layer 'llm' is not defined in this
|
||||
compositor run.``. Until the backend grows a cleanup-only purpose,
|
||||
this layer **does not issue an HTTP cleanup run**: it simply retires
|
||||
the local snapshot row so stale state cannot be re-resumed, and lets
|
||||
the agent backend's own retention TTL release the suspended layers.
|
||||
|
||||
The HTTP-cleanup machinery (``build_cleanup_request`` + ``wait_run``) is
|
||||
intentionally still wired into the request builder + integration tests so
|
||||
that when the agent backend supports cleanup runs we can flip the switch
|
||||
here with a one-line change (see ``_HTTP_CLEANUP_SUPPORTED``).
|
||||
"""
|
||||
|
||||
# Flip to True once dify-agent's runner has a ``purpose=cleanup`` branch
|
||||
# that skips the LLM/output/user-prompt invariants. Until then we only
|
||||
# update the local row; the spec list is still persisted so the future
|
||||
# HTTP cleanup path has everything it needs.
|
||||
_HTTP_CLEANUP_SUPPORTED: bool = False
|
||||
|
||||
_TERMINAL_EVENTS = (
|
||||
GraphRunSucceededEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunAbortedEvent,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
session_store: WorkflowAgentRuntimeSessionStore,
|
||||
request_builder: AgentBackendRunRequestBuilder,
|
||||
agent_backend_client: AgentBackendRunClient | None,
|
||||
cleanup_wait_timeout_seconds: float = _CLEANUP_WAIT_TIMEOUT_SECONDS,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._session_store = session_store
|
||||
self._request_builder = request_builder
|
||||
self._agent_backend_client = agent_backend_client
|
||||
self._cleanup_wait_timeout_seconds = cleanup_wait_timeout_seconds
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
return
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
if not isinstance(event, self._TERMINAL_EVENTS):
|
||||
return
|
||||
workflow_run_id = get_system_text(
|
||||
self.graph_runtime_state.variable_pool,
|
||||
SystemVariableKey.WORKFLOW_EXECUTION_ID,
|
||||
)
|
||||
if not workflow_run_id:
|
||||
logger.warning("Skipping workflow Agent session cleanup: workflow_run_id is missing.")
|
||||
return
|
||||
|
||||
for stored_session in self._session_store.list_active_sessions(workflow_run_id=workflow_run_id):
|
||||
self._cleanup_session(stored_session)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
return
|
||||
|
||||
def _cleanup_session(self, stored_session: StoredWorkflowAgentSession) -> None:
|
||||
scope = stored_session.scope
|
||||
if not self._HTTP_CLEANUP_SUPPORTED:
|
||||
# Agent backend has no cleanup-only run mode yet (see class
|
||||
# docstring). Retire the local row so future re-entries do not
|
||||
# resume from stale state, and let the backend's retention TTL
|
||||
# release the suspended layers on its own schedule.
|
||||
logger.info(
|
||||
"Workflow Agent session retired locally; HTTP cleanup is disabled "
|
||||
"until the agent backend supports a cleanup-only run mode. "
|
||||
"workflow_run_id=%s node_id=%s binding_id=%s agent_id=%s previous_run_id=%s",
|
||||
scope.workflow_run_id,
|
||||
scope.node_id,
|
||||
scope.binding_id,
|
||||
scope.agent_id,
|
||||
stored_session.backend_run_id,
|
||||
)
|
||||
self._session_store.mark_cleaned(scope=scope, backend_run_id=stored_session.backend_run_id)
|
||||
return
|
||||
|
||||
if self._agent_backend_client is None:
|
||||
# HTTP cleanup was enabled by the caller but no client was wired
|
||||
# in (e.g. the API runs without AGENT_BACKEND_BASE_URL configured).
|
||||
# Leave the row ACTIVE so an operator restart with proper config
|
||||
# can drive the cleanup; do not silently retire it.
|
||||
logger.warning(
|
||||
"Skipping Agent backend cleanup: HTTP cleanup is enabled but no agent "
|
||||
"backend client is wired in. workflow_run_id=%s node_id=%s agent_id=%s",
|
||||
scope.workflow_run_id,
|
||||
scope.node_id,
|
||||
scope.agent_id,
|
||||
)
|
||||
return
|
||||
|
||||
if not stored_session.composition_layer_specs:
|
||||
# Sessions persisted before A.1 landed do not carry the spec list,
|
||||
# so we cannot replay a valid cleanup composition. Leave the row
|
||||
# ACTIVE and warn so the absence shows up in observability rather
|
||||
# than being silently swallowed by a doomed cleanup run.
|
||||
logger.warning(
|
||||
"Skipping Agent backend cleanup: no composition_layer_specs persisted. "
|
||||
"workflow_run_id=%s node_id=%s agent_id=%s",
|
||||
scope.workflow_run_id,
|
||||
scope.node_id,
|
||||
scope.agent_id,
|
||||
)
|
||||
return
|
||||
|
||||
request = self._request_builder.build_cleanup_request(
|
||||
session_snapshot=stored_session.session_snapshot,
|
||||
composition_layer_specs=stored_session.composition_layer_specs,
|
||||
idempotency_key=f"{scope.workflow_run_id}:{scope.node_id}:{scope.binding_id}:agent-session-cleanup",
|
||||
metadata={
|
||||
"tenant_id": scope.tenant_id,
|
||||
"app_id": scope.app_id,
|
||||
"workflow_id": scope.workflow_id,
|
||||
"workflow_run_id": scope.workflow_run_id,
|
||||
"node_id": scope.node_id,
|
||||
"node_execution_id": scope.node_execution_id,
|
||||
"binding_id": scope.binding_id,
|
||||
"agent_id": scope.agent_id,
|
||||
"agent_config_snapshot_id": scope.agent_config_snapshot_id,
|
||||
"previous_agent_backend_run_id": stored_session.backend_run_id,
|
||||
},
|
||||
)
|
||||
try:
|
||||
response = self._agent_backend_client.create_run(request)
|
||||
except AgentBackendError:
|
||||
logger.warning(
|
||||
"Agent backend session cleanup request failed: workflow_run_id=%s node_id=%s agent_id=%s",
|
||||
scope.workflow_run_id,
|
||||
scope.node_id,
|
||||
scope.agent_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
status_response = self._agent_backend_client.wait_run(
|
||||
response.run_id, timeout_seconds=self._cleanup_wait_timeout_seconds
|
||||
)
|
||||
except AgentBackendError:
|
||||
logger.warning(
|
||||
"Agent backend session cleanup wait_run failed: "
|
||||
"workflow_run_id=%s node_id=%s agent_id=%s cleanup_run_id=%s",
|
||||
scope.workflow_run_id,
|
||||
scope.node_id,
|
||||
scope.agent_id,
|
||||
response.run_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
if status_response.status != "succeeded":
|
||||
logger.warning(
|
||||
"Agent backend session cleanup did not succeed: status=%s error=%s "
|
||||
"workflow_run_id=%s node_id=%s agent_id=%s cleanup_run_id=%s",
|
||||
status_response.status,
|
||||
status_response.error,
|
||||
scope.workflow_run_id,
|
||||
scope.node_id,
|
||||
scope.agent_id,
|
||||
response.run_id,
|
||||
)
|
||||
return
|
||||
|
||||
self._session_store.mark_cleaned(scope=scope, backend_run_id=response.run_id)
|
||||
|
||||
|
||||
def build_workflow_agent_session_cleanup_layer() -> WorkflowAgentSessionCleanupLayer:
|
||||
"""Wire the cleanup layer with the standard production dependencies.
|
||||
|
||||
The agent backend client is constructed only when ``AGENT_BACKEND_BASE_URL``
|
||||
is configured (or the deterministic fake is explicitly enabled). When
|
||||
neither is set — for example unit tests that bring up the workflow runner
|
||||
without an Agent node — we pass ``None`` so the layer stays harmless. With
|
||||
``_HTTP_CLEANUP_SUPPORTED = False`` the local-retire branch never touches
|
||||
the client anyway, but keeping it ``None`` avoids importing httpx and lets
|
||||
test harnesses skip backend configuration.
|
||||
"""
|
||||
agent_backend_client: AgentBackendRunClient | None
|
||||
if dify_config.AGENT_BACKEND_USE_FAKE or dify_config.AGENT_BACKEND_BASE_URL:
|
||||
agent_backend_client = create_agent_backend_run_client(
|
||||
base_url=dify_config.AGENT_BACKEND_BASE_URL,
|
||||
use_fake=dify_config.AGENT_BACKEND_USE_FAKE,
|
||||
fake_scenario=dify_config.AGENT_BACKEND_FAKE_SCENARIO,
|
||||
)
|
||||
else:
|
||||
agent_backend_client = None
|
||||
|
||||
return WorkflowAgentSessionCleanupLayer(
|
||||
session_store=WorkflowAgentRuntimeSessionStore(),
|
||||
request_builder=AgentBackendRunRequestBuilder(),
|
||||
agent_backend_client=agent_backend_client,
|
||||
)
|
||||
179
api/core/workflow/nodes/agent_v2/session_store.py
Normal file
179
api/core/workflow/nodes/agent_v2/session_store.py
Normal file
@ -0,0 +1,179 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
|
||||
from clients.agent_backend.request_builder import CleanupLayerSpec
|
||||
from core.db.session_factory import session_factory
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.agent import (
|
||||
WorkflowAgentRuntimeSession,
|
||||
WorkflowAgentRuntimeSessionStatus,
|
||||
)
|
||||
|
||||
_SPECS_ADAPTER: TypeAdapter[list[CleanupLayerSpec]] = TypeAdapter(list[CleanupLayerSpec])
|
||||
|
||||
|
||||
def _serialize_specs(specs: list[CleanupLayerSpec]) -> str:
|
||||
return _SPECS_ADAPTER.dump_json(specs).decode()
|
||||
|
||||
|
||||
def _deserialize_specs(value: str | None) -> list[CleanupLayerSpec]:
|
||||
if not value:
|
||||
return []
|
||||
return _SPECS_ADAPTER.validate_json(value)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class WorkflowAgentSessionScope:
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
workflow_id: str
|
||||
workflow_run_id: str | None
|
||||
node_id: str
|
||||
node_execution_id: str
|
||||
binding_id: str
|
||||
agent_id: str
|
||||
agent_config_snapshot_id: str
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class StoredWorkflowAgentSession:
|
||||
scope: WorkflowAgentSessionScope
|
||||
session_snapshot: CompositorSessionSnapshot
|
||||
backend_run_id: str | None
|
||||
composition_layer_specs: list[CleanupLayerSpec] = field(default_factory=list)
|
||||
|
||||
|
||||
class WorkflowAgentRuntimeSessionStore:
|
||||
"""Stores Agent backend session snapshots for workflow Agent node re-entry."""
|
||||
|
||||
def load_active_snapshot(self, scope: WorkflowAgentSessionScope) -> CompositorSessionSnapshot | None:
|
||||
if scope.workflow_run_id is None:
|
||||
return None
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
row = session.scalar(
|
||||
select(WorkflowAgentRuntimeSession).where(
|
||||
WorkflowAgentRuntimeSession.tenant_id == scope.tenant_id,
|
||||
WorkflowAgentRuntimeSession.workflow_run_id == scope.workflow_run_id,
|
||||
WorkflowAgentRuntimeSession.node_id == scope.node_id,
|
||||
WorkflowAgentRuntimeSession.binding_id == scope.binding_id,
|
||||
WorkflowAgentRuntimeSession.agent_id == scope.agent_id,
|
||||
WorkflowAgentRuntimeSession.status == WorkflowAgentRuntimeSessionStatus.ACTIVE,
|
||||
)
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
return CompositorSessionSnapshot.model_validate_json(row.session_snapshot)
|
||||
|
||||
def list_active_sessions(self, *, workflow_run_id: str) -> list[StoredWorkflowAgentSession]:
|
||||
with session_factory.create_session() as session:
|
||||
rows = session.scalars(
|
||||
select(WorkflowAgentRuntimeSession).where(
|
||||
WorkflowAgentRuntimeSession.workflow_run_id == workflow_run_id,
|
||||
WorkflowAgentRuntimeSession.status == WorkflowAgentRuntimeSessionStatus.ACTIVE,
|
||||
)
|
||||
).all()
|
||||
return [
|
||||
StoredWorkflowAgentSession(
|
||||
scope=WorkflowAgentSessionScope(
|
||||
tenant_id=row.tenant_id,
|
||||
app_id=row.app_id,
|
||||
workflow_id=row.workflow_id,
|
||||
workflow_run_id=row.workflow_run_id,
|
||||
node_id=row.node_id,
|
||||
node_execution_id=row.node_execution_id or "",
|
||||
binding_id=row.binding_id,
|
||||
agent_id=row.agent_id,
|
||||
agent_config_snapshot_id=row.agent_config_snapshot_id,
|
||||
),
|
||||
session_snapshot=CompositorSessionSnapshot.model_validate_json(row.session_snapshot),
|
||||
backend_run_id=row.backend_run_id,
|
||||
composition_layer_specs=_deserialize_specs(row.composition_layer_specs),
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
def save_active_snapshot(
|
||||
self,
|
||||
*,
|
||||
scope: WorkflowAgentSessionScope,
|
||||
backend_run_id: str,
|
||||
snapshot: CompositorSessionSnapshot | None,
|
||||
composition_layer_specs: list[CleanupLayerSpec],
|
||||
) -> None:
|
||||
if scope.workflow_run_id is None or snapshot is None:
|
||||
return
|
||||
|
||||
snapshot_json = snapshot.model_dump_json()
|
||||
specs_json = _serialize_specs(composition_layer_specs)
|
||||
with session_factory.create_session() as session:
|
||||
row = session.scalar(
|
||||
select(WorkflowAgentRuntimeSession).where(
|
||||
WorkflowAgentRuntimeSession.tenant_id == scope.tenant_id,
|
||||
WorkflowAgentRuntimeSession.workflow_run_id == scope.workflow_run_id,
|
||||
WorkflowAgentRuntimeSession.node_id == scope.node_id,
|
||||
WorkflowAgentRuntimeSession.binding_id == scope.binding_id,
|
||||
WorkflowAgentRuntimeSession.agent_id == scope.agent_id,
|
||||
)
|
||||
)
|
||||
if row is None:
|
||||
row = WorkflowAgentRuntimeSession(
|
||||
tenant_id=scope.tenant_id,
|
||||
app_id=scope.app_id,
|
||||
workflow_id=scope.workflow_id,
|
||||
workflow_run_id=scope.workflow_run_id,
|
||||
node_id=scope.node_id,
|
||||
node_execution_id=scope.node_execution_id,
|
||||
binding_id=scope.binding_id,
|
||||
agent_id=scope.agent_id,
|
||||
agent_config_snapshot_id=scope.agent_config_snapshot_id,
|
||||
backend_run_id=backend_run_id,
|
||||
session_snapshot=snapshot_json,
|
||||
composition_layer_specs=specs_json,
|
||||
status=WorkflowAgentRuntimeSessionStatus.ACTIVE,
|
||||
)
|
||||
session.add(row)
|
||||
else:
|
||||
row.node_execution_id = scope.node_execution_id
|
||||
row.agent_config_snapshot_id = scope.agent_config_snapshot_id
|
||||
row.backend_run_id = backend_run_id
|
||||
row.session_snapshot = snapshot_json
|
||||
row.composition_layer_specs = specs_json
|
||||
row.status = WorkflowAgentRuntimeSessionStatus.ACTIVE
|
||||
row.cleaned_at = None
|
||||
session.commit()
|
||||
|
||||
def mark_cleaned(self, *, scope: WorkflowAgentSessionScope, backend_run_id: str | None = None) -> None:
|
||||
if scope.workflow_run_id is None:
|
||||
return
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
row = session.scalar(
|
||||
select(WorkflowAgentRuntimeSession).where(
|
||||
WorkflowAgentRuntimeSession.tenant_id == scope.tenant_id,
|
||||
WorkflowAgentRuntimeSession.workflow_run_id == scope.workflow_run_id,
|
||||
WorkflowAgentRuntimeSession.node_id == scope.node_id,
|
||||
WorkflowAgentRuntimeSession.binding_id == scope.binding_id,
|
||||
WorkflowAgentRuntimeSession.agent_id == scope.agent_id,
|
||||
WorkflowAgentRuntimeSession.status == WorkflowAgentRuntimeSessionStatus.ACTIVE,
|
||||
)
|
||||
)
|
||||
if row is None:
|
||||
return
|
||||
if backend_run_id is not None:
|
||||
row.backend_run_id = backend_run_id
|
||||
row.status = WorkflowAgentRuntimeSessionStatus.CLEANED
|
||||
row.cleaned_at = naive_utc_now()
|
||||
session.commit()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"StoredWorkflowAgentSession",
|
||||
"WorkflowAgentRuntimeSessionStore",
|
||||
"WorkflowAgentSessionScope",
|
||||
]
|
||||
@ -126,6 +126,7 @@ class WorkflowAgentNodeValidator:
|
||||
raise WorkflowAgentNodeValidationError(
|
||||
f"Workflow Agent node {binding.node_id} requires Agent Soul model config."
|
||||
)
|
||||
cls._validate_agent_soul_tools(binding=binding, agent_soul=agent_soul)
|
||||
node_job = WorkflowNodeJobConfig.model_validate(binding.node_job_config_dict)
|
||||
cls.validate_node_job(session=session, binding=binding, node_job=node_job, topology=topology)
|
||||
|
||||
@ -280,6 +281,26 @@ class WorkflowAgentNodeValidator:
|
||||
f"Workflow Agent node {binding.node_id} references unsupported human contact channel {channel}."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _validate_agent_soul_tools(
|
||||
cls,
|
||||
*,
|
||||
binding: WorkflowAgentNodeBinding,
|
||||
agent_soul: AgentSoulConfig,
|
||||
) -> None:
|
||||
exposed_names: set[str] = set()
|
||||
for tool in agent_soul.tools.dify_tools:
|
||||
if not tool.enabled:
|
||||
continue
|
||||
exposed_name = tool.tool_name
|
||||
if exposed_name in exposed_names:
|
||||
raise WorkflowAgentNodeValidationError(
|
||||
f"Workflow Agent node {binding.node_id} has duplicate Dify Plugin Tool name {exposed_name}."
|
||||
)
|
||||
exposed_names.add(exposed_name)
|
||||
# CLI tools remain saved-but-not-executed. They are allowed at publish
|
||||
# time so existing Agent Soul drafts are not blocked by a reserved field.
|
||||
|
||||
@staticmethod
|
||||
def _validate_file_ref(
|
||||
*,
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from core.workflow.human_input_adapter import adapt_node_config_for_graph
|
||||
@ -38,6 +39,14 @@ def handle(sender, **kwargs):
|
||||
identity_id=f"WORKFLOW.{app.id}.{node_data.get('id')}",
|
||||
)
|
||||
manager.delete_tool_parameters_cache()
|
||||
except ToolProviderNotFoundError as exc:
|
||||
logger.info(
|
||||
"Skipped deleting tool parameters cache for workflow %s node %s "
|
||||
"because tool provider is missing: %s",
|
||||
app.id,
|
||||
node_data.get("id"),
|
||||
exc,
|
||||
)
|
||||
except Exception:
|
||||
# tool dose not exist
|
||||
logger.exception(
|
||||
|
||||
@ -12,16 +12,21 @@ def init_app(app: DifyApp):
|
||||
clear_orphaned_file_records,
|
||||
convert_to_agent_apps,
|
||||
create_tenant,
|
||||
data_migrate,
|
||||
delete_archived_workflow_runs,
|
||||
export_app_messages,
|
||||
export_migration_data,
|
||||
export_migration_data_template,
|
||||
extract_plugins,
|
||||
extract_unique_plugins,
|
||||
file_usage,
|
||||
fix_app_site_missing,
|
||||
import_migration_data,
|
||||
install_plugins,
|
||||
install_rag_pipeline_plugins,
|
||||
migrate_data_for_plugin,
|
||||
migrate_oss,
|
||||
migration_data_wizard,
|
||||
old_metadata_migration,
|
||||
remove_orphaned_files_on_storage,
|
||||
reset_email,
|
||||
@ -44,6 +49,7 @@ def init_app(app: DifyApp):
|
||||
convert_to_agent_apps,
|
||||
add_qdrant_index,
|
||||
create_tenant,
|
||||
data_migrate,
|
||||
upgrade_db,
|
||||
fix_app_site_missing,
|
||||
migrate_data_for_plugin,
|
||||
@ -68,6 +74,10 @@ def init_app(app: DifyApp):
|
||||
clean_workflow_runs,
|
||||
clean_expired_messages,
|
||||
export_app_messages,
|
||||
export_migration_data,
|
||||
export_migration_data_template,
|
||||
import_migration_data,
|
||||
migration_data_wizard,
|
||||
]
|
||||
for cmd in cmds_to_register:
|
||||
app.cli.add_command(cmd)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import posixpath
|
||||
from collections.abc import Generator
|
||||
from typing import override
|
||||
|
||||
import oss2 as aliyun_s3
|
||||
|
||||
@ -29,9 +30,11 @@ class AliyunOssStorage(BaseStorage):
|
||||
cloudbox_id=dify_config.ALIYUN_CLOUDBOX_ID,
|
||||
)
|
||||
|
||||
@override
|
||||
def save(self, filename, data):
|
||||
self.client.put_object(self.__wrapper_folder_filename(filename), data)
|
||||
|
||||
@override
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
obj = self.client.get_object(self.__wrapper_folder_filename(filename))
|
||||
data = obj.read()
|
||||
@ -39,17 +42,21 @@ class AliyunOssStorage(BaseStorage):
|
||||
return b""
|
||||
return data
|
||||
|
||||
@override
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
obj = self.client.get_object(self.__wrapper_folder_filename(filename))
|
||||
while chunk := obj.read(4096):
|
||||
yield chunk
|
||||
|
||||
@override
|
||||
def download(self, filename: str, target_filepath):
|
||||
self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath)
|
||||
|
||||
@override
|
||||
def exists(self, filename: str):
|
||||
return self.client.object_exists(self.__wrapper_folder_filename(filename))
|
||||
|
||||
@override
|
||||
def delete(self, filename: str):
|
||||
self.client.delete_object(self.__wrapper_folder_filename(filename))
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import override
|
||||
|
||||
import boto3
|
||||
from botocore.client import Config
|
||||
@ -48,9 +49,11 @@ class AwsS3Storage(BaseStorage):
|
||||
# other error, raise exception
|
||||
raise
|
||||
|
||||
@override
|
||||
def save(self, filename, data):
|
||||
self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
|
||||
|
||||
@override
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
try:
|
||||
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
|
||||
@ -61,6 +64,7 @@ class AwsS3Storage(BaseStorage):
|
||||
raise
|
||||
return data
|
||||
|
||||
@override
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
try:
|
||||
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
|
||||
@ -73,9 +77,11 @@ class AwsS3Storage(BaseStorage):
|
||||
else:
|
||||
raise
|
||||
|
||||
@override
|
||||
def download(self, filename, target_filepath):
|
||||
self.client.download_file(self.bucket_name, filename, target_filepath)
|
||||
|
||||
@override
|
||||
def exists(self, filename):
|
||||
try:
|
||||
self.client.head_object(Bucket=self.bucket_name, Key=filename)
|
||||
@ -83,5 +89,6 @@ class AwsS3Storage(BaseStorage):
|
||||
except:
|
||||
return False
|
||||
|
||||
@override
|
||||
def delete(self, filename: str):
|
||||
self.client.delete_object(Bucket=self.bucket_name, Key=filename)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import timedelta
|
||||
from typing import override
|
||||
|
||||
from azure.identity import ChainedTokenCredential, DefaultAzureCredential
|
||||
from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas
|
||||
@ -26,6 +27,7 @@ class AzureBlobStorage(BaseStorage):
|
||||
else:
|
||||
self.credential = None
|
||||
|
||||
@override
|
||||
def save(self, filename, data):
|
||||
if not self.bucket_name:
|
||||
return
|
||||
@ -34,6 +36,7 @@ class AzureBlobStorage(BaseStorage):
|
||||
blob_container = client.get_container_client(container=self.bucket_name)
|
||||
blob_container.upload_blob(filename, data)
|
||||
|
||||
@override
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
if not self.bucket_name:
|
||||
raise FileNotFoundError("Azure bucket name is not configured.")
|
||||
@ -46,6 +49,7 @@ class AzureBlobStorage(BaseStorage):
|
||||
raise TypeError(f"Expected bytes from blob.readall(), got {type(data).__name__}")
|
||||
return data
|
||||
|
||||
@override
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
if not self.bucket_name:
|
||||
raise FileNotFoundError("Azure bucket name is not configured.")
|
||||
@ -55,6 +59,7 @@ class AzureBlobStorage(BaseStorage):
|
||||
blob_data = blob.download_blob()
|
||||
yield from blob_data.chunks()
|
||||
|
||||
@override
|
||||
def download(self, filename, target_filepath):
|
||||
if not self.bucket_name:
|
||||
return
|
||||
@ -66,6 +71,7 @@ class AzureBlobStorage(BaseStorage):
|
||||
blob_data = blob.download_blob()
|
||||
blob_data.readinto(my_blob)
|
||||
|
||||
@override
|
||||
def exists(self, filename):
|
||||
if not self.bucket_name:
|
||||
return False
|
||||
@ -75,6 +81,7 @@ class AzureBlobStorage(BaseStorage):
|
||||
blob = client.get_blob_client(container=self.bucket_name, blob=filename)
|
||||
return blob.exists()
|
||||
|
||||
@override
|
||||
def delete(self, filename: str):
|
||||
if not self.bucket_name:
|
||||
return
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user