mirror of
https://github.com/langgenius/dify.git
synced 2026-06-05 00:16:25 +08:00
Compare commits
73 Commits
feat/dev-s
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| a8f009a965 | |||
| 0bfbd2061e | |||
| c8abb11bf0 | |||
| f9320b2c91 | |||
| f0fd7ddb60 | |||
| b77f5f1e4a | |||
| b67c3a5f76 | |||
| 5b5a06136a | |||
| 6e3c9597ff | |||
| 3c98f96ae8 | |||
| 44725dde74 | |||
| d3058d63bd | |||
| 4fc62d3b38 | |||
| e14cb209a4 | |||
| bb3c9929f9 | |||
| 35a55813d2 | |||
| a247d625e5 | |||
| 5c7f05bd10 | |||
| 02e1a60cde | |||
| 57b573d02b | |||
| 9de40e8f21 | |||
| cad0942f4d | |||
| cb9b1b593e | |||
| 2a8bdc2373 | |||
| ee6a07d13c | |||
| 2d6c9300e3 | |||
| d6b4c800c2 | |||
| 1b37635f92 | |||
| 86af36429d | |||
| b96ea94505 | |||
| d649cccda0 | |||
| 5cbbd78f38 | |||
| 5a0ad4ecd9 | |||
| 1e76b9e1b8 | |||
| 1b972c4e09 | |||
| 7968d2c3c8 | |||
| 7507e9ba67 | |||
| ca31762e26 | |||
| f591da7865 | |||
| f19679b217 | |||
| b682591c7a | |||
| 8f6b59feff | |||
| 99833f65d8 | |||
| 696fc5c213 | |||
| eae44cfecb | |||
| dea4e66456 | |||
| 3cd0da303a | |||
| 888483a2f8 | |||
| 7056985f72 | |||
| 6ce61eae59 | |||
| 079af312c6 | |||
| 0da13dfe4d | |||
| 1ff4d75084 | |||
| e35d23c3cb | |||
| e530e84772 | |||
| 2257a4f1ef | |||
| f465dc5090 | |||
| 5c1cfe6ada | |||
| 8d401d84c7 | |||
| b74287c2ab | |||
| c64d3e98c4 | |||
| a3265f722e | |||
| 5658065b97 | |||
| 8fc2807194 | |||
| fc7716704d | |||
| 71ffaacb58 | |||
| cfc1cf2b8c | |||
| 055d9b9f0a | |||
| 21711bebeb | |||
| becccbf288 | |||
| 86497045c9 | |||
| 687a177b24 | |||
| 4a6d278354 |
@ -1 +0,0 @@
|
||||
../../.agents/skills/frontend-query-mutation
|
||||
1
.claude/skills/how-to-write-component
Symbolic link
1
.claude/skills/how-to-write-component
Symbolic link
@ -0,0 +1 @@
|
||||
../../.agents/skills/how-to-write-component
|
||||
20
.github/workflows/autofix.yml
vendored
20
.github/workflows/autofix.yml
vendored
@ -51,6 +51,15 @@ jobs:
|
||||
with:
|
||||
files: |
|
||||
api/**
|
||||
- name: Check dify-agent inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
id: dify-agent-changes
|
||||
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
|
||||
with:
|
||||
files: |
|
||||
dify-agent/**/*.py
|
||||
dify-agent/pyproject.toml
|
||||
dify-agent/uv.lock
|
||||
- if: github.event_name != 'merge_group'
|
||||
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
|
||||
with:
|
||||
@ -76,6 +85,17 @@ jobs:
|
||||
# Format code
|
||||
uv run ruff format ..
|
||||
|
||||
- if: github.event_name != 'merge_group' && steps.dify-agent-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
cd dify-agent
|
||||
uv sync --dev
|
||||
# fmt first to avoid line too long
|
||||
uv run ruff format .
|
||||
# Fix lint errors
|
||||
uv run ruff check --fix .
|
||||
# Format code
|
||||
uv run ruff format .
|
||||
|
||||
- name: count migration progress
|
||||
if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -259,3 +259,6 @@ scripts/stress-test/reports/
|
||||
.qoder/*
|
||||
.context/
|
||||
.eslintcache
|
||||
|
||||
# Vitest local reports
|
||||
web/.vitest-reports/
|
||||
|
||||
@ -17,7 +17,7 @@ FROM base AS packages
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
# basic environment
|
||||
g++ \
|
||||
git g++ \
|
||||
# for building gmpy2
|
||||
libmpfr-dev libmpc-dev
|
||||
|
||||
@ -97,7 +97,6 @@ RUN \
|
||||
# Copy Python environment and packages
|
||||
ENV VIRTUAL_ENV=/app/api/.venv
|
||||
COPY --from=packages --chown=dify:dify ${VIRTUAL_ENV} ${VIRTUAL_ENV}
|
||||
COPY --from=packages --chown=dify:dify /app/dify-agent /app/dify-agent
|
||||
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
|
||||
|
||||
# Download nltk data
|
||||
|
||||
@ -34,6 +34,7 @@ from clients.agent_backend.request_builder import (
|
||||
DIFY_PLUGIN_TOOLS_LAYER_ID,
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID,
|
||||
AgentBackendAgentAppRunInput,
|
||||
AgentBackendModelConfig,
|
||||
AgentBackendOutputConfig,
|
||||
AgentBackendRunRequestBuilder,
|
||||
@ -49,6 +50,7 @@ __all__ = [
|
||||
"DIFY_PLUGIN_TOOLS_LAYER_ID",
|
||||
"WORKFLOW_NODE_JOB_PROMPT_LAYER_ID",
|
||||
"WORKFLOW_USER_PROMPT_LAYER_ID",
|
||||
"AgentBackendAgentAppRunInput",
|
||||
"AgentBackendError",
|
||||
"AgentBackendHTTPError",
|
||||
"AgentBackendInternalEvent",
|
||||
|
||||
@ -30,6 +30,7 @@ from dify_agent.layers.execution_context import (
|
||||
DifyExecutionContextLayerConfig,
|
||||
)
|
||||
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
|
||||
from dify_agent.layers.shell import DIFY_SHELL_LAYER_TYPE_ID, DifyShellLayerConfig
|
||||
from dify_agent.protocol import (
|
||||
DIFY_AGENT_HISTORY_LAYER_ID,
|
||||
DIFY_AGENT_MODEL_LAYER_ID,
|
||||
@ -45,8 +46,10 @@ from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator
|
||||
AGENT_SOUL_PROMPT_LAYER_ID = "agent_soul_prompt"
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID = "workflow_node_job_prompt"
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID = "workflow_user_prompt"
|
||||
AGENT_APP_USER_PROMPT_LAYER_ID = "agent_app_user_prompt"
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_ID = "execution_context"
|
||||
DIFY_PLUGIN_TOOLS_LAYER_ID = "tools"
|
||||
DIFY_SHELL_LAYER_ID = "shell"
|
||||
|
||||
# Layer types that hold credentials in their per-run config. These are excluded
|
||||
# from the cleanup-replay composition (and from the snapshot that is sent with
|
||||
@ -166,6 +169,10 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
|
||||
idempotency_key: str | None = None
|
||||
output: AgentBackendOutputConfig | None = None
|
||||
tools: DifyPluginToolsLayerConfig | None = None
|
||||
# Inject the sandboxed shell layer (dify.shell). Requires the agent backend
|
||||
# to be wired with a shellctl entrypoint; see configs AGENT_SHELL_ENABLED.
|
||||
include_shell: bool = False
|
||||
shell_config: DifyShellLayerConfig | None = None
|
||||
session_snapshot: CompositorSessionSnapshot | None = None
|
||||
include_history: bool = True
|
||||
suspend_on_exit: bool = True
|
||||
@ -181,9 +188,154 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
|
||||
return value
|
||||
|
||||
|
||||
class AgentBackendAgentAppRunInput(BaseModel):
|
||||
"""Inputs to build one Agent App conversation-turn run request.
|
||||
|
||||
Unlike the workflow-node input there is no workflow-node-job prompt and no
|
||||
previous-node context: the user prompt is the chat message, and multi-turn
|
||||
continuity comes from ``session_snapshot`` + the history layer keyed by the
|
||||
conversation.
|
||||
"""
|
||||
|
||||
model: AgentBackendModelConfig
|
||||
execution_context: DifyExecutionContextLayerConfig
|
||||
user_prompt: str
|
||||
agent_soul_prompt: str | None = None
|
||||
purpose: RunPurpose = "agent_app"
|
||||
idempotency_key: str | None = None
|
||||
output: AgentBackendOutputConfig | None = None
|
||||
tools: DifyPluginToolsLayerConfig | None = None
|
||||
# Inject the sandboxed shell layer (dify.shell). Requires the agent backend
|
||||
# to be wired with a shellctl entrypoint; see configs AGENT_SHELL_ENABLED.
|
||||
include_shell: bool = False
|
||||
shell_config: DifyShellLayerConfig | None = None
|
||||
session_snapshot: CompositorSessionSnapshot | None = None
|
||||
include_history: bool = True
|
||||
suspend_on_exit: bool = True
|
||||
metadata: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
|
||||
@field_validator("user_prompt")
|
||||
@classmethod
|
||||
def _reject_blank_prompt(cls, value: str) -> str:
|
||||
if not value.strip():
|
||||
raise ValueError("prompt must not be blank")
|
||||
return value
|
||||
|
||||
|
||||
class AgentBackendRunRequestBuilder:
|
||||
"""Converts API product state into the public ``dify-agent`` run protocol."""
|
||||
|
||||
def build_for_agent_app(self, run_input: AgentBackendAgentAppRunInput) -> CreateRunRequest:
|
||||
"""Build an Agent App conversation-turn run request.
|
||||
|
||||
Layer graph: optional Agent Soul system prompt → user prompt →
|
||||
execution context → optional history (multi-turn) → LLM → optional
|
||||
plugin tools → optional structured output. Mirrors the workflow-node
|
||||
layer ordering minus the workflow-job / previous-node prompt.
|
||||
"""
|
||||
layers: list[RunLayerSpec] = []
|
||||
if run_input.agent_soul_prompt:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=AGENT_SOUL_PROMPT_LAYER_ID,
|
||||
type=PLAIN_PROMPT_LAYER_TYPE_ID,
|
||||
metadata={**run_input.metadata, "origin": "agent_soul"},
|
||||
config=PromptLayerConfig(prefix=run_input.agent_soul_prompt),
|
||||
)
|
||||
)
|
||||
|
||||
layers.extend(
|
||||
[
|
||||
RunLayerSpec(
|
||||
name=AGENT_APP_USER_PROMPT_LAYER_ID,
|
||||
type=PLAIN_PROMPT_LAYER_TYPE_ID,
|
||||
metadata={**run_input.metadata, "origin": "agent_app_user_prompt"},
|
||||
config=PromptLayerConfig(user=run_input.user_prompt),
|
||||
),
|
||||
RunLayerSpec(
|
||||
name=DIFY_EXECUTION_CONTEXT_LAYER_ID,
|
||||
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.execution_context,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if run_input.include_history:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_HISTORY_LAYER_ID,
|
||||
type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID,
|
||||
metadata={**run_input.metadata, "origin": "agent_session_history"},
|
||||
)
|
||||
)
|
||||
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_MODEL_LAYER_ID,
|
||||
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
|
||||
metadata=run_input.metadata,
|
||||
config=DifyPluginLLMLayerConfig(
|
||||
plugin_id=run_input.model.plugin_id,
|
||||
model_provider=run_input.model.model_provider,
|
||||
model=run_input.model.model,
|
||||
credentials=run_input.model.credentials,
|
||||
model_settings=run_input.model.model_settings or None,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.tools is not None and run_input.tools.tools:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_PLUGIN_TOOLS_LAYER_ID,
|
||||
type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
|
||||
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.tools,
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.include_shell:
|
||||
# Sandboxed bash workspace (dify.shell). The layer declares NoLayerDeps,
|
||||
# so the spec carries no deps; shellctl connection is server-injected.
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_SHELL_LAYER_ID,
|
||||
type=DIFY_SHELL_LAYER_TYPE_ID,
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.shell_config or DifyShellLayerConfig(),
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.output is not None:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_OUTPUT_LAYER_ID,
|
||||
type=DIFY_OUTPUT_LAYER_TYPE_ID,
|
||||
metadata=run_input.metadata,
|
||||
config=DifyOutputLayerConfig(
|
||||
json_schema=run_input.output.json_schema,
|
||||
description=run_input.output.description,
|
||||
strict=run_input.output.strict,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return CreateRunRequest(
|
||||
composition=RunComposition(layers=layers),
|
||||
purpose=run_input.purpose,
|
||||
idempotency_key=run_input.idempotency_key,
|
||||
metadata=run_input.metadata,
|
||||
session_snapshot=run_input.session_snapshot,
|
||||
on_exit=LayerExitSignals(
|
||||
default=ExitIntent.SUSPEND if run_input.suspend_on_exit else ExitIntent.DELETE,
|
||||
),
|
||||
)
|
||||
|
||||
def build_cleanup_request(
|
||||
self,
|
||||
*,
|
||||
@ -302,6 +454,18 @@ class AgentBackendRunRequestBuilder:
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.include_shell:
|
||||
# Sandboxed bash workspace (dify.shell). The layer declares NoLayerDeps,
|
||||
# so the spec carries no deps; shellctl connection is server-injected.
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_SHELL_LAYER_ID,
|
||||
type=DIFY_SHELL_LAYER_TYPE_ID,
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.shell_config or DifyShellLayerConfig(),
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.output is not None:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
|
||||
135
api/clients/agent_backend/workspace_files_client.py
Normal file
135
api/clients/agent_backend/workspace_files_client.py
Normal file
@ -0,0 +1,135 @@
|
||||
"""API-side client for the agent backend's read-only workspace file endpoints.
|
||||
|
||||
The agent backend exposes ``/workspaces/{session_id}/files{,/preview,/download}``
|
||||
to inspect a shell-layer sandbox workspace. This thin synchronous client proxies
|
||||
those reads for the console FS inspector and normalizes transport/HTTP failures
|
||||
into the API backend's ``AgentBackendError`` boundary, preserving the backend's
|
||||
status code and ``{code, message}`` detail so the controller can relay them.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from clients.agent_backend.errors import AgentBackendHTTPError, AgentBackendTransportError
|
||||
|
||||
_DEFAULT_TIMEOUT_SECONDS = 30.0
|
||||
|
||||
|
||||
class WorkspaceFileEntry(BaseModel):
|
||||
"""One entry in a workspace directory listing."""
|
||||
|
||||
name: str
|
||||
type: Literal["file", "dir", "symlink"]
|
||||
size: int
|
||||
mtime: int
|
||||
|
||||
|
||||
class WorkspaceListResult(BaseModel):
|
||||
"""Directory listing of a workspace path."""
|
||||
|
||||
path: str
|
||||
entries: list[WorkspaceFileEntry]
|
||||
truncated: bool
|
||||
|
||||
|
||||
class WorkspacePreviewResult(BaseModel):
|
||||
"""Inline preview of a workspace file."""
|
||||
|
||||
path: str
|
||||
size: int
|
||||
truncated: bool
|
||||
binary: bool
|
||||
text: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class WorkspaceDownloadResult:
|
||||
"""Decoded bytes of a workspace file for download."""
|
||||
|
||||
path: str
|
||||
size: int
|
||||
truncated: bool
|
||||
content: bytes
|
||||
|
||||
|
||||
class WorkspaceFilesBackendClient:
|
||||
"""Synchronous proxy to the agent backend workspace file endpoints."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
*,
|
||||
timeout: float = _DEFAULT_TIMEOUT_SECONDS,
|
||||
transport: httpx.BaseTransport | None = None,
|
||||
) -> None:
|
||||
self._base_url = base_url.rstrip("/")
|
||||
self._timeout = timeout
|
||||
self._transport = transport
|
||||
|
||||
def list_files(self, session_id: str, path: str) -> WorkspaceListResult:
|
||||
data = self._get(f"/workspaces/{session_id}/files", params={"path": path})
|
||||
return WorkspaceListResult.model_validate(data)
|
||||
|
||||
def preview(self, session_id: str, path: str) -> WorkspacePreviewResult:
|
||||
data = self._get(f"/workspaces/{session_id}/files/preview", params={"path": path})
|
||||
return WorkspacePreviewResult.model_validate(data)
|
||||
|
||||
def download(self, session_id: str, path: str) -> WorkspaceDownloadResult:
|
||||
data = self._get(f"/workspaces/{session_id}/files/download", params={"path": path})
|
||||
encoded = data.get("content_base64")
|
||||
if not isinstance(encoded, str):
|
||||
raise AgentBackendHTTPError("agent backend download response missing content", status_code=502, detail=data)
|
||||
try:
|
||||
content = base64.b64decode(encoded, validate=True)
|
||||
except (binascii.Error, ValueError) as exc:
|
||||
raise AgentBackendHTTPError(
|
||||
"agent backend returned undecodable download content", status_code=502, detail=str(exc)
|
||||
) from exc
|
||||
size = data.get("size")
|
||||
return WorkspaceDownloadResult(
|
||||
path=str(data.get("path", path)),
|
||||
size=int(size) if isinstance(size, (int, float)) else len(content),
|
||||
truncated=bool(data.get("truncated")),
|
||||
content=content,
|
||||
)
|
||||
|
||||
def _get(self, route: str, *, params: dict[str, str]) -> dict[str, object]:
|
||||
url = f"{self._base_url}{route}"
|
||||
try:
|
||||
with httpx.Client(timeout=self._timeout, transport=self._transport, trust_env=False) as client:
|
||||
response = client.get(url, params=params)
|
||||
except httpx.HTTPError as exc:
|
||||
raise AgentBackendTransportError(f"failed to reach agent backend workspace endpoint: {exc}") from exc
|
||||
if response.status_code >= 400:
|
||||
detail: object
|
||||
try:
|
||||
detail = response.json().get("detail", response.text)
|
||||
except ValueError:
|
||||
detail = response.text
|
||||
raise AgentBackendHTTPError(
|
||||
f"agent backend workspace request failed ({response.status_code})",
|
||||
status_code=response.status_code,
|
||||
detail=detail,
|
||||
)
|
||||
body = response.json()
|
||||
if not isinstance(body, dict):
|
||||
raise AgentBackendHTTPError(
|
||||
"agent backend workspace response was not an object", status_code=502, detail=body
|
||||
)
|
||||
return body
|
||||
|
||||
|
||||
__all__ = [
|
||||
"WorkspaceDownloadResult",
|
||||
"WorkspaceFileEntry",
|
||||
"WorkspaceFilesBackendClient",
|
||||
"WorkspaceListResult",
|
||||
"WorkspacePreviewResult",
|
||||
]
|
||||
@ -29,6 +29,7 @@ class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def __call__(self) -> dict[str, Any]:
|
||||
current_state = self.current_state
|
||||
remote_source_name = current_state.get("REMOTE_SETTINGS_SOURCE_NAME")
|
||||
|
||||
@ -21,3 +21,13 @@ class AgentBackendConfig(BaseSettings):
|
||||
description="Scenario used by the fake Agent backend client.",
|
||||
default="success",
|
||||
)
|
||||
|
||||
AGENT_SHELL_ENABLED: bool = Field(
|
||||
description=(
|
||||
"Inject the dify.shell layer (sandboxed bash workspace) into Agent runs. "
|
||||
"Requires the agent backend to be wired with a shellctl entrypoint; keep it "
|
||||
"off until shellctl is deployed, otherwise every agent run that includes the "
|
||||
"shell layer will fail."
|
||||
),
|
||||
default=False,
|
||||
)
|
||||
|
||||
@ -81,4 +81,15 @@ default_app_templates: Mapping[AppMode, Mapping] = {
|
||||
},
|
||||
},
|
||||
},
|
||||
# agent default mode (new Agent App type). The runtime model / prompt / tools
|
||||
# come from the bound Agent Soul snapshot, so no model_config is seeded in the
|
||||
# template; create_app still creates a model-less app_model_config row to hold
|
||||
# app-level presentation features (opener, follow-up, citations, ...).
|
||||
AppMode.AGENT: {
|
||||
"app": {
|
||||
"mode": AppMode.AGENT,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -1,10 +1,40 @@
|
||||
import json
|
||||
|
||||
from pydantic import BaseModel, JsonValue
|
||||
from pydantic import BaseModel, Field, JsonValue
|
||||
|
||||
HUMAN_INPUT_FORM_INPUT_EXAMPLE = {
|
||||
"decision": "approve",
|
||||
"attachment": {
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": "4e0d1b87-52f2-49f6-b8c6-95cd9c954b3e",
|
||||
"type": "document",
|
||||
},
|
||||
"attachments": [
|
||||
{
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": "1a77f0df-c0e6-461c-987c-e72526f341ee",
|
||||
"type": "document",
|
||||
},
|
||||
{
|
||||
"transfer_method": "remote_url",
|
||||
"url": "https://example.com/report.pdf",
|
||||
"type": "document",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict[str, JsonValue]
|
||||
inputs: dict[str, JsonValue] = Field(
|
||||
description=(
|
||||
"Submitted human input values keyed by output variable name. "
|
||||
"Use a string for paragraph or select input values, a file mapping for file inputs, "
|
||||
"and a list of file mappings for file-list inputs. Local file mappings use "
|
||||
"`transfer_method=local_file` with `upload_file_id`; remote file mappings use "
|
||||
"`transfer_method=remote_url` with `url` or `remote_url`."
|
||||
),
|
||||
examples=[HUMAN_INPUT_FORM_INPUT_EXAMPLE],
|
||||
)
|
||||
action: str
|
||||
|
||||
|
||||
|
||||
@ -36,6 +36,8 @@ QueryParamDoc = TypedDict(
|
||||
},
|
||||
)
|
||||
|
||||
JsonResponseWithStatus = tuple[dict[str, Any], int]
|
||||
|
||||
|
||||
class QueryArgs(Protocol):
|
||||
def to_dict(self, flat: bool = True) -> dict[str, str]: ...
|
||||
|
||||
@ -51,6 +51,9 @@ from .agent import roster as agent_roster
|
||||
from .app import (
|
||||
advanced_prompt_template,
|
||||
agent,
|
||||
agent_app_access,
|
||||
agent_app_feature,
|
||||
agent_app_workspace,
|
||||
annotation,
|
||||
app,
|
||||
audio,
|
||||
@ -119,7 +122,6 @@ from .explore import (
|
||||
saved_message,
|
||||
trial,
|
||||
)
|
||||
from .snippets import snippet_workflow, snippet_workflow_draft_variable
|
||||
from .socketio import workflow as socketio_workflow
|
||||
|
||||
# Import tag controllers
|
||||
@ -135,7 +137,6 @@ from .workspace import (
|
||||
model_providers,
|
||||
models,
|
||||
plugin,
|
||||
snippets,
|
||||
tool_providers,
|
||||
trigger_providers,
|
||||
workspace,
|
||||
@ -148,6 +149,9 @@ __all__ = [
|
||||
"activate",
|
||||
"advanced_prompt_template",
|
||||
"agent",
|
||||
"agent_app_access",
|
||||
"agent_app_feature",
|
||||
"agent_app_workspace",
|
||||
"agent_composer",
|
||||
"agent_providers",
|
||||
"agent_roster",
|
||||
@ -208,9 +212,6 @@ __all__ = [
|
||||
"saved_message",
|
||||
"setup",
|
||||
"site",
|
||||
"snippet_workflow",
|
||||
"snippet_workflow_draft_variable",
|
||||
"snippets",
|
||||
"socketio_workflow",
|
||||
"spec",
|
||||
"statistic",
|
||||
|
||||
@ -3,7 +3,13 @@ from flask_restx import Resource
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user_id,
|
||||
)
|
||||
from fields.agent_fields import (
|
||||
AgentAppComposerResponse,
|
||||
AgentComposerCandidatesResponse,
|
||||
@ -12,7 +18,7 @@ from fields.agent_fields import (
|
||||
WorkflowAgentComposerResponse,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppMode
|
||||
from services.agent.composer_service import AgentComposerService
|
||||
from services.agent.composer_validator import ComposerConfigValidator
|
||||
@ -38,8 +44,8 @@ class WorkflowAgentComposerApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model: App, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App, node_id: str):
|
||||
return dump_response(
|
||||
WorkflowAgentComposerResponse,
|
||||
AgentComposerService.load_workflow_composer(
|
||||
@ -58,8 +64,9 @@ class WorkflowAgentComposerApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def put(self, app_model: App, node_id: str):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def put(self, tenant_id: str, account_id: str, app_model: App, node_id: str):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return dump_response(
|
||||
WorkflowAgentComposerResponse,
|
||||
@ -67,7 +74,7 @@ class WorkflowAgentComposerApi(Resource):
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
account_id=account.id,
|
||||
account_id=account_id,
|
||||
payload=payload,
|
||||
),
|
||||
)
|
||||
@ -113,8 +120,8 @@ class WorkflowAgentComposerImpactApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, app_model: App, node_id: str):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
current_snapshot_id = payload.binding.current_snapshot_id if payload.binding else None
|
||||
if not current_snapshot_id:
|
||||
@ -138,8 +145,9 @@ class WorkflowAgentComposerSaveToRosterApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, account_id: str, app_model: App, node_id: str):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return dump_response(
|
||||
WorkflowAgentComposerResponse,
|
||||
@ -147,7 +155,7 @@ class WorkflowAgentComposerSaveToRosterApi(Resource):
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
account_id=account.id,
|
||||
account_id=account_id,
|
||||
payload=payload,
|
||||
),
|
||||
)
|
||||
@ -160,8 +168,8 @@ class AgentAppComposerApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App):
|
||||
return dump_response(
|
||||
AgentAppComposerResponse,
|
||||
AgentComposerService.load_agent_app_composer(tenant_id=tenant_id, app_id=app_model.id),
|
||||
@ -174,15 +182,16 @@ class AgentAppComposerApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model()
|
||||
def put(self, app_model: App):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def put(self, tenant_id: str, account_id: str, app_model: App):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return dump_response(
|
||||
AgentAppComposerResponse,
|
||||
AgentComposerService.save_agent_app_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
account_id=account.id,
|
||||
account_id=account_id,
|
||||
payload=payload,
|
||||
),
|
||||
)
|
||||
|
||||
@ -6,7 +6,13 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user_id,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.agent_fields import (
|
||||
AgentConfigSnapshotDetailResponse,
|
||||
@ -16,7 +22,7 @@ from fields.agent_fields import (
|
||||
AgentRosterResponse,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from services.agent.roster_service import AgentRosterService
|
||||
from services.entities.agent_entities import RosterAgentCreatePayload, RosterAgentUpdatePayload, RosterListQuery
|
||||
|
||||
@ -58,8 +64,8 @@ class AgentRosterListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
query = RosterListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
return dump_response(
|
||||
AgentRosterListResponse,
|
||||
@ -74,11 +80,12 @@ class AgentRosterListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, account_id: str):
|
||||
payload = RosterAgentCreatePayload.model_validate(console_ns.payload or {})
|
||||
service = _agent_roster_service()
|
||||
agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account.id, payload=payload)
|
||||
agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account_id, payload=payload)
|
||||
return dump_response(
|
||||
AgentRosterResponse,
|
||||
service.get_roster_agent_detail(tenant_id=tenant_id, agent_id=agent.id),
|
||||
@ -92,8 +99,8 @@ class AgentInviteOptionsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
query = AgentInviteOptionsQuery.model_validate(request.args.to_dict(flat=True))
|
||||
return dump_response(
|
||||
AgentInviteOptionsResponse,
|
||||
@ -113,8 +120,8 @@ class AgentRosterDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id: UUID):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID):
|
||||
return dump_response(
|
||||
AgentRosterResponse,
|
||||
_agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id)),
|
||||
@ -126,13 +133,14 @@ class AgentRosterDetailApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def patch(self, agent_id: UUID):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def patch(self, tenant_id: str, account_id: str, agent_id: UUID):
|
||||
payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
return dump_response(
|
||||
AgentRosterResponse,
|
||||
_agent_roster_service().update_roster_agent(
|
||||
tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id, payload=payload
|
||||
tenant_id=tenant_id, agent_id=str(agent_id), account_id=account_id, payload=payload
|
||||
),
|
||||
)
|
||||
|
||||
@ -141,9 +149,10 @@ class AgentRosterDetailApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, agent_id: UUID):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id)
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def delete(self, tenant_id: str, account_id: str, agent_id: UUID):
|
||||
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account_id)
|
||||
return "", 204
|
||||
|
||||
|
||||
@ -153,8 +162,8 @@ class AgentRosterVersionsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id: UUID):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID):
|
||||
return dump_response(
|
||||
AgentConfigSnapshotListResponse,
|
||||
{"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))},
|
||||
@ -167,8 +176,8 @@ class AgentRosterVersionDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id: UUID, version_id: UUID):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID, version_id: UUID):
|
||||
return dump_response(
|
||||
AgentConfigSnapshotDetailResponse,
|
||||
_agent_roster_service().get_agent_version_detail(
|
||||
|
||||
59
api/controllers/console/app/agent_app_access.py
Normal file
59
api/controllers/console/app/agent_app_access.py
Normal file
@ -0,0 +1,59 @@
|
||||
"""Agent App access & sharing endpoints (read-only workflow references).
|
||||
|
||||
An Agent App is backed by a roster Agent that workflow Agent nodes may also
|
||||
reference. This exposes the read-only "Workflow access" surface from the PRD:
|
||||
which workflow apps use this Agent, without leaking the workflows' internals.
|
||||
"""
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import Field
|
||||
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppMode
|
||||
from services.agent.roster_service import AgentRosterService
|
||||
|
||||
|
||||
class AgentReferencingWorkflowResponse(ResponseModel):
|
||||
app_id: str
|
||||
app_name: str
|
||||
app_mode: str
|
||||
workflow_id: str
|
||||
node_ids: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentReferencingWorkflowsResponse(ResponseModel):
|
||||
data: list[AgentReferencingWorkflowResponse] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_response_schema_models(console_ns, AgentReferencingWorkflowsResponse)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-referencing-workflows")
|
||||
class AgentAppReferencingWorkflowsResource(Resource):
|
||||
@console_ns.doc("list_agent_app_referencing_workflows")
|
||||
@console_ns.doc(description="List workflow apps that reference this Agent App's bound Agent (read-only)")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Referencing workflows listed successfully",
|
||||
console_ns.models[AgentReferencingWorkflowsResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "App not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App):
|
||||
workflows = AgentRosterService(db.session).list_workflows_referencing_app_agent(
|
||||
tenant_id=tenant_id, app_id=app_model.id
|
||||
)
|
||||
return AgentReferencingWorkflowsResponse(
|
||||
data=[AgentReferencingWorkflowResponse.model_validate(workflow) for workflow in workflows]
|
||||
).model_dump(mode="json")
|
||||
93
api/controllers/console/app/agent_app_feature.py
Normal file
93
api/controllers/console/app/agent_app_feature.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""Agent App presentation-feature configuration endpoint.
|
||||
|
||||
The new Agent App type keeps model / prompt / tools in its bound Agent Soul, so
|
||||
the legacy ``/model-config`` surface (which writes model, prompt and agent tool
|
||||
config) is the wrong place to configure its app-level presentation features.
|
||||
This endpoint exposes only the PRD "Misc Legacy" feature subset — conversation
|
||||
opener, follow-up suggestions, citations, content moderation and speech — and
|
||||
persists them onto the app's ``app_model_config`` without touching anything the
|
||||
Soul owns.
|
||||
"""
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.agent_config_entities import (
|
||||
AgentFeatureToggleConfig,
|
||||
AgentSensitiveWordAvoidanceFeatureConfig,
|
||||
AgentSuggestedQuestionsAfterAnswerFeatureConfig,
|
||||
AgentTextToSpeechFeatureConfig,
|
||||
)
|
||||
from models.model import App, AppMode
|
||||
from services.agent_app_feature_service import AgentAppFeatureConfigService
|
||||
|
||||
|
||||
class AgentAppFeaturesPayload(BaseModel):
|
||||
"""Presentation features configurable on an Agent App.
|
||||
|
||||
All fields are optional; an omitted field is reset to its disabled/empty
|
||||
default (the config form sends the full desired feature state on save).
|
||||
"""
|
||||
|
||||
opening_statement: str | None = Field(default=None, description="Conversation opener shown before the first turn")
|
||||
suggested_questions: list[str] | None = Field(
|
||||
default=None, description="Preset questions shown alongside the opener"
|
||||
)
|
||||
suggested_questions_after_answer: AgentSuggestedQuestionsAfterAnswerFeatureConfig | None = Field(
|
||||
default=None, description="Follow-up suggestions config, e.g. {'enabled': true}"
|
||||
)
|
||||
speech_to_text: AgentFeatureToggleConfig | None = Field(default=None, description="Speech-to-text config")
|
||||
text_to_speech: AgentTextToSpeechFeatureConfig | None = Field(default=None, description="Text-to-speech config")
|
||||
retriever_resource: AgentFeatureToggleConfig | None = Field(
|
||||
default=None, description="Citations / attributions config, e.g. {'enabled': true}"
|
||||
)
|
||||
sensitive_word_avoidance: AgentSensitiveWordAvoidanceFeatureConfig | None = Field(
|
||||
default=None, description="Content moderation config"
|
||||
)
|
||||
|
||||
|
||||
register_schema_models(console_ns, AgentAppFeaturesPayload)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-features")
|
||||
class AgentAppFeatureConfigResource(Resource):
|
||||
@console_ns.doc("update_agent_app_features")
|
||||
@console_ns.doc(description="Update an Agent App's presentation features (opener, follow-up, citations, ...)")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AgentAppFeaturesPayload.__name__])
|
||||
@console_ns.response(200, "Features updated successfully", console_ns.models[SimpleResultResponse.__name__])
|
||||
@console_ns.response(400, "Invalid configuration")
|
||||
@console_ns.response(404, "App not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
args = AgentAppFeaturesPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
new_app_model_config = AgentAppFeatureConfigService.update_features(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
config=args.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
|
||||
|
||||
return dump_response(SimpleResultResponse, {"result": "success"})
|
||||
319
api/controllers/console/app/agent_app_workspace.py
Normal file
319
api/controllers/console/app/agent_app_workspace.py
Normal file
@ -0,0 +1,319 @@
|
||||
"""Agent App sandbox file-system inspector (read-only).
|
||||
|
||||
Exposes the PRD "rc1-like sandbox file system, downloadable not editable" view
|
||||
for an Agent App conversation: list a directory, preview a file, or download a
|
||||
file from the conversation's shell-layer workspace. The API never touches
|
||||
shellctl directly — it resolves the conversation's sandbox ``session_id`` from
|
||||
the stored session snapshot and proxies to the agent backend's read-only
|
||||
workspace endpoints.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from clients.agent_backend.errors import AgentBackendHTTPError, AgentBackendTransportError
|
||||
from clients.agent_backend.workspace_files_client import WorkspaceDownloadResult
|
||||
from controllers.common.schema import (
|
||||
query_params_from_model,
|
||||
query_params_from_request,
|
||||
register_response_schema_models,
|
||||
)
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import App, AppMode
|
||||
from services.agent_app_workspace_service import (
|
||||
AgentAppWorkspaceService,
|
||||
AgentWorkspaceInspectorError,
|
||||
WorkflowAgentWorkspaceService,
|
||||
)
|
||||
|
||||
|
||||
class _WorkspaceFileDownloadField(fields.Raw):
|
||||
__schema_type__ = "string"
|
||||
__schema_format__ = "binary"
|
||||
|
||||
|
||||
class AgentWorkspaceListQuery(BaseModel):
|
||||
conversation_id: str = Field(min_length=1, description="Agent App conversation ID")
|
||||
path: str = Field(default=".", description="Directory path relative to the sandbox workspace")
|
||||
|
||||
|
||||
class AgentWorkspaceFileQuery(BaseModel):
|
||||
conversation_id: str = Field(min_length=1, description="Agent App conversation ID")
|
||||
path: str = Field(min_length=1, description="File path relative to the sandbox workspace")
|
||||
|
||||
|
||||
class WorkflowAgentWorkspaceListQuery(BaseModel):
|
||||
path: str = Field(default=".", description="Directory path relative to the sandbox workspace")
|
||||
node_execution_id: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional workflow node execution ID. When omitted, the latest active session for the node is used."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class WorkflowAgentWorkspaceFileQuery(BaseModel):
|
||||
path: str = Field(min_length=1, description="File path relative to the sandbox workspace")
|
||||
node_execution_id: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional workflow node execution ID. When omitted, the latest active session for the node is used."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class WorkspaceFileEntryResponse(ResponseModel):
|
||||
name: str
|
||||
type: Literal["file", "dir", "symlink"]
|
||||
size: int
|
||||
mtime: int
|
||||
|
||||
|
||||
class WorkspaceListResponse(ResponseModel):
|
||||
path: str
|
||||
entries: list[WorkspaceFileEntryResponse] = Field(default_factory=list)
|
||||
truncated: bool = False
|
||||
|
||||
|
||||
class WorkspacePreviewResponse(ResponseModel):
|
||||
path: str
|
||||
size: int
|
||||
truncated: bool
|
||||
binary: bool
|
||||
text: str | None = None
|
||||
|
||||
|
||||
register_response_schema_models(console_ns, WorkspaceListResponse)
|
||||
register_response_schema_models(console_ns, WorkspacePreviewResponse)
|
||||
|
||||
|
||||
def _handle(exc: Exception) -> tuple[dict[str, object], int]:
|
||||
if isinstance(exc, AgentWorkspaceInspectorError):
|
||||
return {"code": exc.code, "message": exc.message}, exc.status_code
|
||||
if isinstance(exc, AgentBackendHTTPError):
|
||||
detail = exc.detail
|
||||
if isinstance(detail, dict):
|
||||
return {
|
||||
"code": detail.get("code", "agent_backend_error"),
|
||||
"message": detail.get("message", str(exc)),
|
||||
}, exc.status_code
|
||||
return {"code": "agent_backend_error", "message": str(detail)}, exc.status_code
|
||||
if isinstance(exc, AgentBackendTransportError):
|
||||
return {"code": "agent_backend_unreachable", "message": str(exc)}, 502
|
||||
raise exc
|
||||
|
||||
|
||||
def _download_response(result: WorkspaceDownloadResult) -> Response | tuple[dict[str, object], int]:
|
||||
if result.truncated:
|
||||
return {
|
||||
"code": "workspace_file_too_large",
|
||||
"message": (
|
||||
"file exceeds the workspace download limit; use preview for partial text or download a smaller file"
|
||||
),
|
||||
"size": result.size,
|
||||
}, 413
|
||||
filename = result.path.rsplit("/", 1)[-1] or "download"
|
||||
return Response(
|
||||
result.content,
|
||||
mimetype="application/octet-stream",
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{filename}"',
|
||||
"Content-Length": str(len(result.content)),
|
||||
"X-Workspace-File-Size": str(result.size),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-workspace/files")
|
||||
class AgentAppWorkspaceListResource(Resource):
|
||||
@console_ns.doc("list_agent_app_workspace_files")
|
||||
@console_ns.doc(description="List a directory in an Agent App conversation's sandbox workspace (read-only)")
|
||||
@console_ns.doc(params={"app_id": "Application ID", **query_params_from_model(AgentWorkspaceListQuery)})
|
||||
@console_ns.response(200, "Listing returned", console_ns.models[WorkspaceListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
def get(self, app_model: App):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
query = query_params_from_request(AgentWorkspaceListQuery)
|
||||
try:
|
||||
result = AgentAppWorkspaceService().list_files(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
conversation_id=query.conversation_id,
|
||||
path=query.path,
|
||||
)
|
||||
except Exception as exc: # normalized to an HTTP response below
|
||||
return _handle(exc)
|
||||
return result.model_dump()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-workspace/files/preview")
|
||||
class AgentAppWorkspacePreviewResource(Resource):
|
||||
@console_ns.doc("preview_agent_app_workspace_file")
|
||||
@console_ns.doc(description="Preview a text/binary file in an Agent App conversation's sandbox workspace")
|
||||
@console_ns.doc(params={"app_id": "Application ID", **query_params_from_model(AgentWorkspaceFileQuery)})
|
||||
@console_ns.response(200, "Preview returned", console_ns.models[WorkspacePreviewResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
def get(self, app_model: App):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
query = query_params_from_request(AgentWorkspaceFileQuery)
|
||||
try:
|
||||
result = AgentAppWorkspaceService().preview(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
conversation_id=query.conversation_id,
|
||||
path=query.path,
|
||||
)
|
||||
except Exception as exc: # normalized to an HTTP response below
|
||||
return _handle(exc)
|
||||
return result.model_dump()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-workspace/files/download")
|
||||
class AgentAppWorkspaceDownloadResource(Resource):
|
||||
@console_ns.doc("download_agent_app_workspace_file")
|
||||
@console_ns.doc(description="Download a file from an Agent App conversation's sandbox workspace (read-only)")
|
||||
@console_ns.doc(params={"app_id": "Application ID", **query_params_from_model(AgentWorkspaceFileQuery)})
|
||||
@console_ns.doc(produces=["application/octet-stream"])
|
||||
@console_ns.response(200, "File bytes", _WorkspaceFileDownloadField)
|
||||
@console_ns.response(413, "File exceeds the workspace download limit")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
def get(self, app_model: App):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
query = query_params_from_request(AgentWorkspaceFileQuery)
|
||||
try:
|
||||
result = AgentAppWorkspaceService().download(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
conversation_id=query.conversation_id,
|
||||
path=query.path,
|
||||
)
|
||||
except Exception as exc: # normalized to an HTTP response below
|
||||
return _handle(exc)
|
||||
return _download_response(result)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/apps/<uuid:app_id>/workflow-runs/<uuid:workflow_run_id>/agent-nodes/<string:node_id>/workspace/files"
|
||||
)
|
||||
class WorkflowAgentWorkspaceListResource(Resource):
|
||||
@console_ns.doc("list_workflow_agent_workspace_files")
|
||||
@console_ns.doc(description="List a directory in a Workflow Agent node's sandbox workspace (read-only)")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"workflow_run_id": "Workflow run ID",
|
||||
"node_id": "Workflow Agent node ID",
|
||||
**query_params_from_model(WorkflowAgentWorkspaceListQuery),
|
||||
}
|
||||
)
|
||||
@console_ns.response(200, "Listing returned", console_ns.models[WorkspaceListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
query = query_params_from_request(WorkflowAgentWorkspaceListQuery)
|
||||
try:
|
||||
result = WorkflowAgentWorkspaceService().list_files(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
workflow_run_id=str(workflow_run_id),
|
||||
node_id=node_id,
|
||||
node_execution_id=query.node_execution_id,
|
||||
path=query.path,
|
||||
)
|
||||
except Exception as exc: # normalized to an HTTP response below
|
||||
return _handle(exc)
|
||||
return result.model_dump()
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/apps/<uuid:app_id>/workflow-runs/<uuid:workflow_run_id>/agent-nodes/<string:node_id>/workspace/files/preview"
|
||||
)
|
||||
class WorkflowAgentWorkspacePreviewResource(Resource):
|
||||
@console_ns.doc("preview_workflow_agent_workspace_file")
|
||||
@console_ns.doc(description="Preview a text/binary file in a Workflow Agent node's sandbox workspace")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"workflow_run_id": "Workflow run ID",
|
||||
"node_id": "Workflow Agent node ID",
|
||||
**query_params_from_model(WorkflowAgentWorkspaceFileQuery),
|
||||
}
|
||||
)
|
||||
@console_ns.response(200, "Preview returned", console_ns.models[WorkspacePreviewResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
query = query_params_from_request(WorkflowAgentWorkspaceFileQuery)
|
||||
try:
|
||||
result = WorkflowAgentWorkspaceService().preview(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
workflow_run_id=str(workflow_run_id),
|
||||
node_id=node_id,
|
||||
node_execution_id=query.node_execution_id,
|
||||
path=query.path,
|
||||
)
|
||||
except Exception as exc: # normalized to an HTTP response below
|
||||
return _handle(exc)
|
||||
return result.model_dump()
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/apps/<uuid:app_id>/workflow-runs/<uuid:workflow_run_id>/agent-nodes/<string:node_id>/workspace/files/download"
|
||||
)
|
||||
class WorkflowAgentWorkspaceDownloadResource(Resource):
|
||||
@console_ns.doc("download_workflow_agent_workspace_file")
|
||||
@console_ns.doc(description="Download a file from a Workflow Agent node's sandbox workspace (read-only)")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"workflow_run_id": "Workflow run ID",
|
||||
"node_id": "Workflow Agent node ID",
|
||||
**query_params_from_model(WorkflowAgentWorkspaceFileQuery),
|
||||
}
|
||||
)
|
||||
@console_ns.doc(produces=["application/octet-stream"])
|
||||
@console_ns.response(200, "File bytes", _WorkspaceFileDownloadField)
|
||||
@console_ns.response(413, "File exceeds the workspace download limit")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
query = query_params_from_request(WorkflowAgentWorkspaceFileQuery)
|
||||
try:
|
||||
result = WorkflowAgentWorkspaceService().download(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
workflow_run_id=str(workflow_run_id),
|
||||
node_id=node_id,
|
||||
node_execution_id=query.node_execution_id,
|
||||
path=query.path,
|
||||
)
|
||||
except Exception as exc: # normalized to an HTTP response below
|
||||
return _handle(exc)
|
||||
return _download_response(result)
|
||||
@ -25,6 +25,9 @@ from controllers.console.wraps import (
|
||||
enterprise_license_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
with_current_user_id,
|
||||
)
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||
@ -34,8 +37,8 @@ from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from libs.helper import build_icon_url, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, DatasetPermissionEnum, Workflow
|
||||
from libs.login import login_required
|
||||
from models import Account, App, DatasetPermissionEnum, Workflow
|
||||
from models.model import IconType
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_service import AppListParams, AppService, CreateAppParams
|
||||
@ -55,7 +58,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
)
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "agent", "advanced-chat", "workflow", "completion"]
|
||||
|
||||
register_enum_models(console_ns, IconType)
|
||||
|
||||
@ -66,7 +69,7 @@ _TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
|
||||
class AppListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
|
||||
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"] = Field(
|
||||
default="all", description="App mode filter"
|
||||
)
|
||||
name: str | None = Field(default=None, description="Filter by app name")
|
||||
@ -115,7 +118,9 @@ def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str,
|
||||
class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
|
||||
mode: Literal["chat", "agent-chat", "agent", "advanced-chat", "workflow", "completion"] = Field(
|
||||
..., description="App mode"
|
||||
)
|
||||
icon_type: IconType | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
@ -393,6 +398,8 @@ class AppDetailWithSite(AppDetail):
|
||||
max_active_requests: int | None = None
|
||||
deleted_tools: list[DeletedTool] = Field(default_factory=list)
|
||||
site: Site | None = None
|
||||
# For Agent App type: the roster Agent backing this app (None otherwise).
|
||||
bound_agent_id: str | None = None
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
@ -468,10 +475,10 @@ class AppListApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@with_session(write=False)
|
||||
def get(self, session: Session):
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user_id: str, session: Session):
|
||||
"""Get app list"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args))
|
||||
params = AppListParams(
|
||||
page=args.page,
|
||||
@ -484,7 +491,7 @@ class AppListApi(Resource):
|
||||
|
||||
# get app list
|
||||
app_service = AppService()
|
||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, params)
|
||||
app_pagination = app_service.get_paginate_apps(current_user_id, current_tenant_id, params)
|
||||
if not app_pagination:
|
||||
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
|
||||
return empty.model_dump(mode="json"), 200
|
||||
@ -544,9 +551,10 @@ class AppListApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
"""Create app"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
args = CreateAppPayload.model_validate(console_ns.payload)
|
||||
params = CreateAppParams(
|
||||
name=args.name,
|
||||
@ -649,11 +657,10 @@ class AppCopyApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""Copy app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = CopyAppPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
@ -732,7 +739,8 @@ class AppPublishToCreatorsPlatformApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, app_model: App):
|
||||
"""Publish app to Creators Platform"""
|
||||
from configs import dify_config
|
||||
from core.helper.creators import get_redirect_url, upload_dsl
|
||||
@ -740,13 +748,11 @@ class AppPublishToCreatorsPlatformApi(Resource):
|
||||
if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
|
||||
return {"error": "Creators Platform features are not enabled"}, 403
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
|
||||
dsl_bytes = dsl_content.encode("utf-8")
|
||||
|
||||
claim_code = upload_dsl(dsl_bytes)
|
||||
redirect_url = get_redirect_url(str(current_user.id), claim_code)
|
||||
redirect_url = get_redirect_url(current_user_id, claim_code)
|
||||
|
||||
return {"redirect_url": redirect_url}
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import Any, Literal
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
@ -19,7 +19,12 @@ from controllers.console.app.error import (
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user_id,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
@ -41,9 +46,24 @@ from services.errors.llm import InvokeRateLimitError
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_debugger_chat_streaming(
|
||||
*, app_mode: AppMode, response_mode: str, response_mode_provided: bool = True
|
||||
) -> bool:
|
||||
"""Agent App runtime is SSE-only until backend blocking runs are supported."""
|
||||
if app_mode != AppMode.AGENT:
|
||||
return response_mode != "blocking"
|
||||
if response_mode_provided and response_mode == "blocking":
|
||||
raise BadRequest("Agent App only supports streaming response mode.")
|
||||
return True
|
||||
|
||||
|
||||
class BaseMessagePayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config")
|
||||
# Agent Apps (AppMode.AGENT) derive their model + prompt from the bound Agent
|
||||
# Soul, so no override ``model_config`` is sent; chat / agent-chat / completion
|
||||
# debugging still pass it. Optional here, required in practice by those modes
|
||||
# downstream when their config is built from args.
|
||||
model_config_data: dict[str, Any] = Field(default_factory=dict, alias="model_config")
|
||||
files: list[Any] | None = Field(default=None, description="Uploaded files")
|
||||
response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
|
||||
retriever_from: str = Field(default="dev", description="Retriever source")
|
||||
@ -131,14 +151,13 @@ class CompletionMessageStopApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model: App, task_id: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, app_model: App, task_id: str):
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
@ -157,13 +176,20 @@ class ChatMessageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
args_model = ChatMessagePayload.model_validate(console_ns.payload)
|
||||
raw_payload = console_ns.payload or {}
|
||||
args_model = ChatMessagePayload.model_validate(raw_payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
streaming = args_model.response_mode != "blocking"
|
||||
streaming = _resolve_debugger_chat_streaming(
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
response_mode=args_model.response_mode,
|
||||
response_mode_provided=isinstance(raw_payload, dict) and "response_mode" in raw_payload,
|
||||
)
|
||||
if AppMode.value_of(app_model.mode) == AppMode.AGENT:
|
||||
args["response_mode"] = "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
@ -211,15 +237,14 @@ class ChatMessageStopApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model: App, task_id: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, app_model: App, task_id: str):
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
|
||||
@ -212,7 +212,7 @@ class ChatConversationApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
@ -323,7 +323,7 @@ class ChatConversationDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App, conversation_id: UUID):
|
||||
@ -340,7 +340,7 @@ class ChatConversationDetailApi(Resource):
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
@ -25,6 +26,7 @@ from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from services.workflow_generator_service import WorkflowGeneratorService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
@ -42,6 +44,24 @@ class InstructionTemplatePayload(BaseModel):
|
||||
type: str = Field(..., description="Instruction template type")
|
||||
|
||||
|
||||
class WorkflowGeneratePayload(BaseModel):
|
||||
"""Payload for the cmd+k `/create` and `/refine` workflow generator endpoint.
|
||||
|
||||
See ``services/workflow_generator_service.py`` for behaviour. Errors are
|
||||
surfaced through the same envelope as ``/rule-generate`` so the frontend
|
||||
can reuse its existing handler.
|
||||
"""
|
||||
|
||||
mode: Literal["workflow", "advanced-chat"] = Field(..., description="Target app mode for the generated graph")
|
||||
instruction: str = Field(..., description="Natural-language workflow description")
|
||||
ideal_output: str = Field(default="", description="Optional sample output for grounding")
|
||||
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
|
||||
current_graph: dict | None = Field(
|
||||
default=None,
|
||||
description="Existing draft graph to refine (cmd+k `/refine`); omit for create-from-scratch",
|
||||
)
|
||||
|
||||
|
||||
register_enum_models(console_ns, LLMMode)
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
@ -50,6 +70,7 @@ register_schema_models(
|
||||
RuleStructuredOutputPayload,
|
||||
InstructionGeneratePayload,
|
||||
InstructionTemplatePayload,
|
||||
WorkflowGeneratePayload,
|
||||
ModelConfig,
|
||||
)
|
||||
|
||||
@ -265,3 +286,56 @@ class InstructionGenerationTemplateApi(Resource):
|
||||
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
|
||||
case _:
|
||||
raise ValueError(f"Invalid type: {args.type}")
|
||||
|
||||
|
||||
@console_ns.route("/workflow-generate")
|
||||
class WorkflowGenerateApi(Resource):
|
||||
"""Generate a Workflow / Chatflow draft graph from a natural-language description.
|
||||
|
||||
Triggered by the cmd+k `/create` slash command. Returns a graph payload
|
||||
shaped exactly like ``WorkflowService.sync_draft_workflow``'s input, so the
|
||||
frontend can hand it straight to ``/apps/{id}/workflows/draft``.
|
||||
"""
|
||||
|
||||
@console_ns.doc("generate_workflow_graph")
|
||||
@console_ns.doc(description="Generate a Dify workflow graph from natural language")
|
||||
@console_ns.expect(console_ns.models[WorkflowGeneratePayload.__name__])
|
||||
@console_ns.response(200, "Workflow graph generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
args = WorkflowGeneratePayload.model_validate(console_ns.payload)
|
||||
|
||||
# Reject obviously-empty instructions at the boundary — Pydantic only
|
||||
# validates ``instruction`` is a str, but a whitespace-only string
|
||||
# would still hit the LLM and waste a planner+builder roundtrip on a
|
||||
# response that the postprocess validator would reject anyway.
|
||||
if not args.instruction.strip():
|
||||
return {
|
||||
"error": "Instruction is required",
|
||||
"errors": [{"code": "EMPTY_INSTRUCTION", "detail": "Instruction is required"}],
|
||||
}, 400
|
||||
|
||||
try:
|
||||
result = WorkflowGeneratorService.generate_workflow_graph(
|
||||
tenant_id=current_tenant_id,
|
||||
mode=args.mode,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
ideal_output=args.ideal_output,
|
||||
current_graph=args.current_graph,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
|
||||
return result
|
||||
|
||||
@ -180,7 +180,7 @@ class ChatMessageListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict())
|
||||
@ -337,7 +337,7 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App, message_id: UUID):
|
||||
message_id_str = str(message_id)
|
||||
|
||||
@ -8,14 +8,20 @@ from pydantic import BaseModel, Field
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user_id,
|
||||
)
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
@ -52,9 +58,10 @@ class ModelConfigResource(Resource):
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
||||
def post(self, app_model: App):
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user_id: str, app_model: App):
|
||||
"""Modify app model config"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# validate config
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_tenant_id,
|
||||
@ -64,8 +71,8 @@ class ModelConfigResource(Resource):
|
||||
|
||||
new_app_model_config = AppModelConfig(
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
updated_by=current_user.id,
|
||||
created_by=current_user_id,
|
||||
updated_by=current_user_id,
|
||||
)
|
||||
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
|
||||
|
||||
@ -90,7 +97,7 @@ class ModelConfigResource(Resource):
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
)
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_tenant_id,
|
||||
@ -130,7 +137,7 @@ class ModelConfigResource(Resource):
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
@ -167,7 +174,7 @@ class ModelConfigResource(Resource):
|
||||
db.session.flush()
|
||||
|
||||
app_model.app_model_config_id = new_app_model_config.id
|
||||
app_model.updated_by = current_user.id
|
||||
app_model.updated_by = current_user_id
|
||||
app_model.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@ -290,7 +290,7 @@ class AverageSessionInteractionStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, Concatenate, TypedDict
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
@ -214,7 +214,9 @@ workflow_draft_variable_list_model = console_ns.model(
|
||||
)
|
||||
|
||||
|
||||
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
def _api_prerequisite[T, **P, R](
|
||||
f: Callable[Concatenate[T, P], R],
|
||||
) -> Callable[Concatenate[T, P], R | Response]:
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
@ -231,8 +233,8 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@wraps(f)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(*args, **kwargs)
|
||||
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@ -8,9 +8,16 @@ from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
@ -32,8 +39,9 @@ class Subscription(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True))
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
|
||||
@ -45,8 +53,9 @@ class Invoices(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
return BillingService.get_invoices(current_user.email, current_tenant_id)
|
||||
|
||||
@ -63,9 +72,8 @@ class PartnerTenants(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def put(self, partner_key: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def put(self, current_user: Account, partner_key: str):
|
||||
try:
|
||||
args = PartnerTenantsPayload.model_validate(console_ns.payload or {})
|
||||
click_id = args.click_id
|
||||
|
||||
@ -3,11 +3,18 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.billing_service import BillingService
|
||||
|
||||
from .. import console_ns
|
||||
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from ..wraps import (
|
||||
account_initialization_required,
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
|
||||
|
||||
class ComplianceDownloadQuery(BaseModel):
|
||||
@ -29,8 +36,9 @@ class ComplianceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
|
||||
@ -1,41 +1,37 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, cast
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_serializer
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.fields import SimpleResultResponse, TextContentResponse
|
||||
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_model
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.entities.knowledge_entities import IndexingEstimate
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from extensions.ext_database import db
|
||||
from fields.data_source_fields import (
|
||||
integrate_fields,
|
||||
integrate_icon_fields,
|
||||
integrate_list_fields,
|
||||
integrate_notion_info_list_fields,
|
||||
integrate_page_fields,
|
||||
integrate_workspace_fields,
|
||||
)
|
||||
from fields.base import ResponseModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DataSourceOauthBinding, Document
|
||||
from libs.helper import dump_response, to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import Account, DataSourceOauthBinding, Document
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
from .. import console_ns
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
from ..wraps import account_initialization_required, setup_required, with_current_tenant_id, with_current_user
|
||||
|
||||
|
||||
class NotionEstimatePayload(BaseModel):
|
||||
@ -54,50 +50,74 @@ class DataSourceNotionPreviewQuery(BaseModel):
|
||||
credential_id: str = Field(..., description="Credential ID", min_length=1)
|
||||
|
||||
|
||||
register_schema_model(console_ns, NotionEstimatePayload)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse, TextContentResponse)
|
||||
class DataSourceIntegrateIconResponse(ResponseModel):
|
||||
type: str | None = None
|
||||
url: str | None = None
|
||||
emoji: str | None = None
|
||||
|
||||
|
||||
integrate_icon_model = get_or_create_model("DataSourceIntegrateIcon", integrate_icon_fields)
|
||||
class DataSourceIntegratePageResponse(ResponseModel):
|
||||
page_name: str
|
||||
page_id: str
|
||||
page_icon: DataSourceIntegrateIconResponse | None
|
||||
parent_id: str
|
||||
type: str
|
||||
|
||||
integrate_page_fields_copy = integrate_page_fields.copy()
|
||||
integrate_page_fields_copy["page_icon"] = fields.Nested(integrate_icon_model, allow_null=True)
|
||||
integrate_page_model = get_or_create_model("DataSourceIntegratePage", integrate_page_fields_copy)
|
||||
|
||||
integrate_workspace_fields_copy = integrate_workspace_fields.copy()
|
||||
integrate_workspace_fields_copy["pages"] = fields.List(fields.Nested(integrate_page_model))
|
||||
integrate_workspace_model = get_or_create_model("DataSourceIntegrateWorkspace", integrate_workspace_fields_copy)
|
||||
class DataSourceIntegrateWorkspaceResponse(ResponseModel):
|
||||
workspace_name: str | None
|
||||
workspace_id: str | None
|
||||
workspace_icon: str | None
|
||||
pages: list[DataSourceIntegratePageResponse]
|
||||
total: int
|
||||
|
||||
integrate_fields_copy = integrate_fields.copy()
|
||||
integrate_fields_copy["source_info"] = fields.Nested(integrate_workspace_model)
|
||||
integrate_model = get_or_create_model("DataSourceIntegrate", integrate_fields_copy)
|
||||
|
||||
integrate_list_fields_copy = integrate_list_fields.copy()
|
||||
integrate_list_fields_copy["data"] = fields.List(fields.Nested(integrate_model))
|
||||
integrate_list_model = get_or_create_model("DataSourceIntegrateList", integrate_list_fields_copy)
|
||||
class DataSourceIntegrateResponse(ResponseModel):
|
||||
id: str | None
|
||||
provider: str
|
||||
created_at: datetime | int | None
|
||||
is_bound: bool
|
||||
disabled: bool | None
|
||||
link: str
|
||||
source_info: DataSourceIntegrateWorkspaceResponse | None
|
||||
|
||||
notion_page_fields = {
|
||||
"page_name": fields.String,
|
||||
"page_id": fields.String,
|
||||
"page_icon": fields.Nested(integrate_icon_model, allow_null=True),
|
||||
"is_bound": fields.Boolean,
|
||||
"parent_id": fields.String,
|
||||
"type": fields.String,
|
||||
}
|
||||
notion_page_model = get_or_create_model("NotionIntegratePage", notion_page_fields)
|
||||
@field_serializer("created_at")
|
||||
def serialize_created_at(self, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
notion_workspace_fields = {
|
||||
"workspace_name": fields.String,
|
||||
"workspace_id": fields.String,
|
||||
"workspace_icon": fields.String,
|
||||
"pages": fields.List(fields.Nested(notion_page_model)),
|
||||
}
|
||||
notion_workspace_model = get_or_create_model("NotionIntegrateWorkspace", notion_workspace_fields)
|
||||
|
||||
integrate_notion_info_list_fields_copy = integrate_notion_info_list_fields.copy()
|
||||
integrate_notion_info_list_fields_copy["notion_info"] = fields.List(fields.Nested(notion_workspace_model))
|
||||
integrate_notion_info_list_model = get_or_create_model(
|
||||
"NotionIntegrateInfoList", integrate_notion_info_list_fields_copy
|
||||
class DataSourceIntegrateListResponse(ResponseModel):
|
||||
data: list[DataSourceIntegrateResponse]
|
||||
|
||||
|
||||
class NotionIntegratePageResponse(ResponseModel):
|
||||
page_name: str
|
||||
page_id: str
|
||||
page_icon: DataSourceIntegrateIconResponse | None
|
||||
parent_id: str | None
|
||||
type: str
|
||||
is_bound: bool
|
||||
|
||||
|
||||
class NotionIntegrateWorkspaceResponse(ResponseModel):
|
||||
workspace_name: str | None
|
||||
workspace_id: str | None
|
||||
workspace_icon: str | None
|
||||
pages: list[NotionIntegratePageResponse]
|
||||
|
||||
|
||||
class NotionIntegrateInfoListResponse(ResponseModel):
|
||||
notion_info: list[NotionIntegrateWorkspaceResponse]
|
||||
|
||||
|
||||
register_schema_models(console_ns, NotionEstimatePayload)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
DataSourceIntegrateListResponse,
|
||||
IndexingEstimate,
|
||||
NotionIntegrateInfoListResponse,
|
||||
SimpleResultResponse,
|
||||
TextContentResponse,
|
||||
)
|
||||
|
||||
|
||||
@ -109,10 +129,9 @@ class DataSourceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(integrate_list_model)
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@console_ns.response(200, "Success", console_ns.models[DataSourceIntegrateListResponse.__name__])
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str) -> tuple[dict[str, Any], int]:
|
||||
# get workspace data source integrates
|
||||
data_source_integrates = db.session.scalars(
|
||||
select(DataSourceOauthBinding).where(
|
||||
@ -154,19 +173,21 @@ class DataSourceApi(Resource):
|
||||
"link": f"{base_url}{data_source_oauth_base_path}/{provider}",
|
||||
}
|
||||
)
|
||||
return {"data": integrate_data}, 200
|
||||
return dump_response(DataSourceIntegrateListResponse, {"data": integrate_data}), 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def patch(self, binding_id, action: Literal["enable", "disable"]):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
binding_id = str(binding_id)
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self, current_tenant_id: str, binding_id: UUID, action: Literal["enable", "disable"]
|
||||
) -> tuple[dict[str, str], int]:
|
||||
binding_id_str = str(binding_id)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
data_source_binding = session.execute(
|
||||
select(DataSourceOauthBinding).where(
|
||||
DataSourceOauthBinding.id == binding_id, DataSourceOauthBinding.tenant_id == current_tenant_id
|
||||
DataSourceOauthBinding.id == binding_id_str, DataSourceOauthBinding.tenant_id == current_tenant_id
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if data_source_binding is None:
|
||||
@ -198,12 +219,12 @@ class DataSourceNotionListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(integrate_notion_info_list_model)
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
query = DataSourceNotionListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
@console_ns.doc(params=query_params_from_model(DataSourceNotionListQuery))
|
||||
@console_ns.response(200, "Success", console_ns.models[NotionIntegrateInfoListResponse.__name__])
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account) -> tuple[dict[str, Any], int]:
|
||||
query = DataSourceNotionListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
@ -278,22 +299,22 @@ class DataSourceNotionListApi(Resource):
|
||||
pages.append(page_info)
|
||||
except Exception as e:
|
||||
raise e
|
||||
return {"notion_info": {**workspace_info, "pages": pages}}, 200
|
||||
notion_info = [{**workspace_info, "pages": pages}] if workspace_info else []
|
||||
return dump_response(NotionIntegrateInfoListResponse, {"notion_info": notion_info}), 200
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/notion/pages/<uuid:page_id>/<string:page_type>/preview",
|
||||
"/datasets/notion-indexing-estimate",
|
||||
)
|
||||
class DataSourceNotionApi(Resource):
|
||||
@console_ns.route("/notion/pages/<uuid:page_id>/<string:page_type>/preview")
|
||||
class DataSourceNotionPreviewApi(Resource):
|
||||
"""Preview one authorized Notion page through the datasource credential."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.doc(params=query_params_from_model(DataSourceNotionPreviewQuery))
|
||||
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
|
||||
def get(self, page_id: UUID, page_type: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict())
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, page_id: UUID, page_type: str) -> tuple[dict[str, str], int]:
|
||||
query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
@ -316,13 +337,18 @@ class DataSourceNotionApi(Resource):
|
||||
text_docs = extractor.extract()
|
||||
return {"content": "\n".join([doc.page_content for doc in text_docs])}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/notion-indexing-estimate")
|
||||
class DataSourceNotionIndexingEstimateApi(Resource):
|
||||
"""Estimate indexing work for selected Notion pages."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[NotionEstimatePayload.__name__])
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@console_ns.response(200, "Success", console_ns.models[IndexingEstimate.__name__])
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str) -> tuple[dict[str, Any], int]:
|
||||
payload = NotionEstimatePayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump()
|
||||
# validate args
|
||||
@ -355,7 +381,7 @@ class DataSourceNotionApi(Resource):
|
||||
args["doc_form"],
|
||||
args["doc_language"],
|
||||
)
|
||||
return response.model_dump(), 200
|
||||
return dump_response(IndexingEstimate, response), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/notion/sync")
|
||||
@ -364,7 +390,7 @@ class DataSourceNotionDatasetSyncApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id: UUID) -> tuple[dict[str, str], int]:
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -382,7 +408,7 @@ class DataSourceNotionDocumentSyncApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
def get(self, dataset_id: UUID, document_id: UUID) -> tuple[dict[str, str], int]:
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
|
||||
@ -44,8 +44,8 @@ from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import dump_response, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from libs.login import login_required
|
||||
from models import Account, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from models.enums import IndexingStatus, SegmentStatus
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
@ -71,6 +71,8 @@ from ..wraps import (
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -169,8 +171,9 @@ register_response_schema_models(
|
||||
|
||||
|
||||
class DocumentResource(Resource):
|
||||
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
def get_document(
|
||||
self, dataset_id: str, document_id: str, current_user: Account, current_tenant_id: str
|
||||
) -> Document:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -190,8 +193,7 @@ class DocumentResource(Resource):
|
||||
|
||||
return document
|
||||
|
||||
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
def get_batch_documents(self, dataset_id: str, batch: str, current_user: Account) -> Sequence[Document]:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -218,8 +220,8 @@ class GetProcessRuleApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account):
|
||||
req_data = request.args
|
||||
|
||||
document_id = req_data.get("document_id")
|
||||
@ -279,8 +281,9 @@ class DatasetDocumentListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
raw_args = request.args.to_dict()
|
||||
param = DocumentDatasetListParam.model_validate(raw_args)
|
||||
@ -405,8 +408,8 @@ class DatasetDocumentListApi(Resource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
|
||||
@console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__])
|
||||
def post(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -480,9 +483,10 @@ class DatasetInitApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
@ -539,11 +543,12 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
@ -604,10 +609,11 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, batch: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, batch: str):
|
||||
dataset_id_str = str(dataset_id)
|
||||
documents = self.get_batch_documents(dataset_id_str, batch)
|
||||
documents = self.get_batch_documents(dataset_id_str, batch, current_user)
|
||||
if not documents:
|
||||
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
|
||||
data_process_rule = documents[0].dataset_process_rule
|
||||
@ -704,9 +710,10 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, batch: str):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, dataset_id: UUID, batch: str):
|
||||
dataset_id_str = str(dataset_id)
|
||||
documents = self.get_batch_documents(dataset_id_str, batch)
|
||||
documents = self.get_batch_documents(dataset_id_str, batch, current_user)
|
||||
documents_status = []
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
@ -759,16 +766,18 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
completed_segments = (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document_id_str),
|
||||
DocumentSegment.document_id == document_id_str,
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
@ -777,7 +786,7 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
total_segments = (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.document_id == str(document_id_str),
|
||||
DocumentSegment.document_id == document_id_str,
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
@ -820,10 +829,12 @@ class DocumentApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
metadata = request.args.get("metadata", "all")
|
||||
if metadata not in self.METADATA_CHOICES:
|
||||
@ -909,7 +920,9 @@ class DocumentApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Document deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -918,7 +931,7 @@ class DocumentApi(DocumentResource):
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
try:
|
||||
DocumentService.delete_document(document)
|
||||
@ -939,9 +952,11 @@ class DocumentDownloadApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def get(self, dataset_id: UUID, document_id: UUID) -> dict[str, Any]:
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID) -> dict[str, Any]:
|
||||
# Reuse the shared permission/tenant checks implemented in DocumentResource.
|
||||
document = self.get_document(str(dataset_id), str(document_id))
|
||||
document = self.get_document(str(dataset_id), str(document_id), current_user, current_tenant_id)
|
||||
return {"url": DocumentService.get_document_download_url(document)}
|
||||
|
||||
|
||||
@ -956,12 +971,13 @@ class DocumentBatchDownloadZipApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[DocumentBatchDownloadZipPayload.__name__])
|
||||
def post(self, dataset_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
|
||||
"""Stream a ZIP archive containing the requested uploaded documents."""
|
||||
# Parse and validate request payload.
|
||||
payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_ids: list[str] = [str(document_id) for document_id in payload.document_ids]
|
||||
upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
|
||||
@ -1003,11 +1019,19 @@ class DocumentProcessingApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["pause", "resume"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
dataset_id: UUID,
|
||||
document_id: UUID,
|
||||
action: Literal["pause", "resume"],
|
||||
):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||
if not current_user.is_dataset_editor:
|
||||
@ -1051,11 +1075,12 @@ class DocumentMetadataApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, dataset_id: UUID, document_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def put(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {})
|
||||
|
||||
@ -1100,8 +1125,10 @@ class DocumentStatusApi(DocumentResource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def patch(self, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def patch(
|
||||
self, current_user: Account, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]
|
||||
):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -1216,8 +1243,6 @@ class DocumentRetryApi(DocumentResource):
|
||||
raise NotFound("Dataset not found.")
|
||||
for document_id in payload.document_ids:
|
||||
try:
|
||||
document_id = str(document_id)
|
||||
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
|
||||
# 404 if document not found
|
||||
@ -1248,9 +1273,9 @@ class DocumentRenameApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -1273,9 +1298,9 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""sync website document."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if not dataset:
|
||||
@ -1351,7 +1376,8 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id: UUID):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
"""
|
||||
Generate summary index for specified documents.
|
||||
|
||||
@ -1359,7 +1385,6 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
(indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
|
||||
then asynchronously generates summary indexes for the provided documents.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
# Get dataset
|
||||
@ -1444,7 +1469,8 @@ class DocumentSummaryStatusApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
"""
|
||||
Get summary index generation status for a document.
|
||||
|
||||
@ -1457,7 +1483,6 @@ class DocumentSummaryStatusApi(DocumentResource):
|
||||
- not_started: Number of segments without summary records
|
||||
- summaries: List of summary records with status and content preview
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
|
||||
|
||||
@ -33,6 +33,8 @@ from controllers.console.wraps import (
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_manager import ModelManager
|
||||
@ -51,7 +53,8 @@ from fields.segment_fields import (
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.helper import dump_response, escape_like_pattern
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import ChildChunk, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
@ -164,9 +167,9 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -274,9 +277,8 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
|
||||
@console_ns.doc(params=query_params_from_model(SegmentIdListQuery))
|
||||
@console_ns.response(204, "Segments deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -312,9 +314,16 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["enable", "disable"]):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
dataset_id: UUID,
|
||||
document_id: UUID,
|
||||
action: Literal["enable", "disable"],
|
||||
):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if not dataset:
|
||||
@ -373,9 +382,9 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
|
||||
@console_ns.response(200, "Segment created successfully", console_ns.models[SegmentDetailResponse.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -431,9 +440,11 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Segment updated successfully", console_ns.models[SegmentDetailResponse.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -500,9 +511,11 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
|
||||
@console_ns.response(204, "Segment deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(
|
||||
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -548,9 +561,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[BatchImportPayload.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -619,9 +632,11 @@ class ChildChunkAddApi(Resource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
|
||||
@console_ns.response(200, "Child chunk created successfully", console_ns.models[ChildChunkDetailResponse.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(
|
||||
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -677,9 +692,8 @@ class ChildChunkAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -731,9 +745,11 @@ class ChildChunkAddApi(Resource):
|
||||
console_ns.models[ChildChunkBatchUpdateResponse.__name__],
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ChildChunkBatchUpdatePayload.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -781,9 +797,17 @@ class ChildChunkUpdateApi(Resource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
|
||||
@console_ns.response(204, "Child chunk deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(
|
||||
self,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
dataset_id: UUID,
|
||||
document_id: UUID,
|
||||
segment_id: UUID,
|
||||
child_chunk_id: UUID,
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -840,9 +864,17 @@ class ChildChunkUpdateApi(Resource):
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
|
||||
@console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Child chunk updated successfully", console_ns.models[ChildChunkDetailResponse.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
dataset_id: UUID,
|
||||
document_id: UUID,
|
||||
segment_id: UUID,
|
||||
child_chunk_id: UUID,
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
|
||||
@ -15,6 +15,7 @@ from controllers.console.wraps import (
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from fields.dataset_fields import (
|
||||
dataset_detail_fields,
|
||||
@ -29,7 +30,8 @@ from fields.dataset_fields import (
|
||||
vector_setting_fields,
|
||||
weighted_score_fields,
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
@ -152,8 +154,9 @@ class ExternalApiTemplateListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
|
||||
def post(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
ExternalDatasetService.validate_api_list(payload.settings)
|
||||
@ -182,8 +185,8 @@ class ExternalApiTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, external_knowledge_api_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, external_knowledge_api_id: UUID):
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(
|
||||
external_knowledge_api_id_str, current_tenant_id
|
||||
@ -197,8 +200,9 @@ class ExternalApiTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
|
||||
def patch(self, external_knowledge_api_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(self, current_tenant_id: str, current_user: Account, external_knowledge_api_id: UUID):
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
|
||||
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
|
||||
@ -217,8 +221,9 @@ class ExternalApiTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(204, "External knowledge API deleted successfully")
|
||||
def delete(self, external_knowledge_api_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, current_user: Account, external_knowledge_api_id: UUID):
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
||||
@ -237,8 +242,8 @@ class ExternalApiUseCheckApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, external_knowledge_api_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, external_knowledge_api_id: UUID):
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
|
||||
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
|
||||
@ -259,9 +264,10 @@ class ExternalDatasetCreateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
payload = ExternalDatasetCreatePayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
@ -293,8 +299,8 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
|
||||
@ -9,11 +9,18 @@ from configs import dify_config
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.provider_ids import DatasourceProviderID
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
@ -66,11 +73,10 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, provider_id: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, provider_id: str):
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
credential_id = request.args.get("credential_id")
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
@ -174,9 +180,8 @@ class DatasourceAuth(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
@ -195,15 +200,17 @@ class DatasourceAuth(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, user: Account, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasources = datasource_provider_service.list_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
user=user,
|
||||
)
|
||||
return {"result": datasources}, 200
|
||||
|
||||
@ -216,9 +223,8 @@ class DatasourceAuthDeleteApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
@ -241,9 +247,8 @@ class DatasourceAuthUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
payload = DatasourceCredentialUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
@ -264,9 +269,8 @@ class DatasourceAuthListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
@ -277,9 +281,8 @@ class DatasourceHardCodeAuthListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
@ -292,9 +295,8 @@ class DatasourceAuthOauthCustomClient(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
payload = DatasourceCustomClientPayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
@ -310,9 +312,8 @@ class DatasourceAuthOauthCustomClient(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def delete(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_oauth_custom_client_params(
|
||||
@ -330,9 +331,8 @@ class DatasourceAuthDefaultApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
payload = DatasourceDefaultPayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
@ -352,9 +352,8 @@ class DatasourceUpdateProviderNameApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
payload = DatasourceUpdateNamePayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
|
||||
@ -1,13 +1,20 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.fields import SimpleDataResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.common.schema import (
|
||||
JsonResponseWithStatus,
|
||||
query_params_from_model,
|
||||
register_response_schema_models,
|
||||
register_schema_models,
|
||||
)
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@ -16,79 +23,132 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from models.dataset import PipelineCustomizedTemplate
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PipelineTemplateListQuery(BaseModel):
|
||||
type: str = Field(default="built-in", description="Template source: built-in or customized")
|
||||
language: str = Field(default="en-US", description="Template language")
|
||||
|
||||
|
||||
class PipelineTemplateDetailQuery(BaseModel):
|
||||
type: str = Field(default="built-in", description="Template source: built-in or customized")
|
||||
|
||||
|
||||
class PipelineTemplateItemResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
icon: dict[str, Any]
|
||||
description: str
|
||||
position: int
|
||||
chunk_structure: str
|
||||
copyright: str | None = None
|
||||
privacy_policy: str | None = None
|
||||
|
||||
|
||||
class PipelineTemplateListResponse(ResponseModel):
|
||||
pipeline_templates: list[PipelineTemplateItemResponse]
|
||||
|
||||
|
||||
class PipelineTemplateDetailResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
icon_info: dict[str, Any]
|
||||
description: str
|
||||
chunk_structure: str
|
||||
export_data: str
|
||||
graph: dict[str, Any]
|
||||
created_by: str | None = None
|
||||
|
||||
|
||||
class CustomizedPipelineTemplatePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=40)
|
||||
description: str = Field(default="", max_length=400)
|
||||
icon_info: dict[str, object] = Field(default_factory=lambda: IconInfo(icon="").model_dump())
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
CustomizedPipelineTemplatePayload,
|
||||
PipelineTemplateDetailQuery,
|
||||
PipelineTemplateListQuery,
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
PipelineTemplateDetailResponse,
|
||||
PipelineTemplateListResponse,
|
||||
SimpleDataResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/templates")
|
||||
class PipelineTemplateListApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(PipelineTemplateListQuery))
|
||||
@console_ns.response(200, "Pipeline templates", console_ns.models[PipelineTemplateListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
type = request.args.get("type", default="built-in", type=str)
|
||||
language = request.args.get("language", default="en-US", type=str)
|
||||
def get(self) -> JsonResponseWithStatus:
|
||||
query = PipelineTemplateListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
# get pipeline templates
|
||||
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
|
||||
return pipeline_templates, 200
|
||||
pipeline_templates = RagPipelineService.get_pipeline_templates(query.type, query.language)
|
||||
return dump_response(PipelineTemplateListResponse, pipeline_templates), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/templates/<string:template_id>")
|
||||
class PipelineTemplateDetailApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(PipelineTemplateDetailQuery))
|
||||
@console_ns.response(200, "Pipeline template", console_ns.models[PipelineTemplateDetailResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self, template_id: str):
|
||||
type = request.args.get("type", default="built-in", type=str)
|
||||
def get(self, template_id: str) -> JsonResponseWithStatus:
|
||||
query = PipelineTemplateDetailQuery.model_validate(request.args.to_dict(flat=True))
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
|
||||
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, query.type)
|
||||
if pipeline_template is None:
|
||||
return {"error": "Pipeline template not found from upstream service."}, 404
|
||||
return pipeline_template, 200
|
||||
|
||||
|
||||
class Payload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=40)
|
||||
description: str = Field(default="", max_length=400)
|
||||
icon_info: dict[str, object] | None = None
|
||||
|
||||
|
||||
register_schema_models(console_ns, Payload)
|
||||
register_response_schema_models(console_ns, SimpleDataResponse)
|
||||
raise NotFound("Pipeline template not found from upstream service.")
|
||||
return dump_response(PipelineTemplateDetailResponse, pipeline_template), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
|
||||
class CustomizedPipelineTemplateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[CustomizedPipelineTemplatePayload.__name__])
|
||||
@console_ns.response(204, "Pipeline template updated")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def patch(self, template_id: str):
|
||||
payload = Payload.model_validate(console_ns.payload or {})
|
||||
def patch(self, template_id: str) -> tuple[str, int]:
|
||||
payload = CustomizedPipelineTemplatePayload.model_validate(console_ns.payload or {})
|
||||
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(payload.model_dump())
|
||||
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
||||
return 200
|
||||
return "", 204
|
||||
|
||||
@console_ns.response(204, "Pipeline template deleted")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def delete(self, template_id: str):
|
||||
def delete(self, template_id: str) -> tuple[str, int]:
|
||||
RagPipelineService.delete_customized_pipeline_template(template_id)
|
||||
return 200
|
||||
return "", 204
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleDataResponse.__name__])
|
||||
def post(self, template_id: str):
|
||||
def post(self, template_id: str) -> JsonResponseWithStatus:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
template = session.scalar(
|
||||
select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).limit(1)
|
||||
@ -96,19 +156,20 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||
if not template:
|
||||
raise ValueError("Customized pipeline template not found.")
|
||||
|
||||
return {"data": template.yaml_content}, 200
|
||||
return dump_response(SimpleDataResponse, {"data": template.yaml_content}), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish")
|
||||
class PublishCustomizedPipelineTemplateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[Payload.__name__])
|
||||
@console_ns.expect(console_ns.models[CustomizedPipelineTemplatePayload.__name__])
|
||||
@console_ns.response(204, "Pipeline template published")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@knowledge_pipeline_publish_enabled
|
||||
def post(self, pipeline_id: str):
|
||||
payload = Payload.model_validate(console_ns.payload or {})
|
||||
def post(self, pipeline_id: str) -> tuple[str, int]:
|
||||
payload = CustomizedPipelineTemplatePayload.model_validate(console_ns.payload or {})
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, payload.model_dump())
|
||||
return {"result": "success"}
|
||||
return "", 204
|
||||
|
||||
@ -1,20 +1,25 @@
|
||||
from flask_restx import Resource, marshal
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.common.schema import JsonResponseWithStatus, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import RagPipelineImportResponse
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from fields.dataset_fields import DatasetDetailResponse
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
|
||||
@ -25,19 +30,26 @@ class RagPipelineDatasetImportPayload(BaseModel):
|
||||
yaml_content: str
|
||||
|
||||
|
||||
register_schema_model(console_ns, RagPipelineDatasetImportPayload)
|
||||
register_schema_models(console_ns, RagPipelineDatasetImportPayload)
|
||||
register_response_schema_models(console_ns, DatasetDetailResponse, RagPipelineImportResponse)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/dataset")
|
||||
class CreateRagPipelineDatasetApi(Resource):
|
||||
@console_ns.expect(console_ns.models[RagPipelineDatasetImportPayload.__name__])
|
||||
@console_ns.response(
|
||||
201,
|
||||
"RAG pipeline dataset import started",
|
||||
console_ns.models[RagPipelineImportResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account) -> JsonResponseWithStatus:
|
||||
payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {})
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
@ -70,19 +82,20 @@ class CreateRagPipelineDatasetApi(Resource):
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return import_info, 201
|
||||
return dump_response(RagPipelineImportResponse, import_info), 201
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/empty-dataset")
|
||||
class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||
@console_ns.response(201, "RAG pipeline dataset created", console_ns.models[DatasetDetailResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account) -> JsonResponseWithStatus:
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
@ -99,4 +112,4 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||
partial_member_list=None,
|
||||
),
|
||||
)
|
||||
return marshal(dataset, dataset_detail_fields), 201
|
||||
return dump_response(DatasetDetailResponse, dataset), 201
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any, NoReturn
|
||||
from typing import Any, Concatenate, NoReturn
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
@ -57,7 +57,9 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
|
||||
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
|
||||
|
||||
|
||||
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
def _api_prerequisite[T, **P, R](
|
||||
f: Callable[Concatenate[T, P], R],
|
||||
) -> Callable[Concatenate[T, P], R | Response]:
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
@ -72,10 +74,10 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
return f(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@ -1,9 +1,15 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with # type: ignore
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import get_or_create_model, register_schema_models
|
||||
from controllers.common.fields import SimpleDataResponse
|
||||
from controllers.common.schema import (
|
||||
JsonResponseWithStatus,
|
||||
query_params_from_model,
|
||||
register_response_schema_models,
|
||||
register_schema_models,
|
||||
)
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import (
|
||||
@ -12,12 +18,10 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from core.plugin.entities.plugin import PluginDependency
|
||||
from extensions.ext_database import db
|
||||
from fields.rag_pipeline_fields import (
|
||||
leaked_dependency_fields,
|
||||
pipeline_import_check_dependencies_fields,
|
||||
pipeline_import_fields,
|
||||
)
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline
|
||||
@ -38,34 +42,44 @@ class RagPipelineImportPayload(BaseModel):
|
||||
|
||||
|
||||
class IncludeSecretQuery(BaseModel):
|
||||
include_secret: str = Field(default="false")
|
||||
include_secret: str = Field(default="false", description="Whether to include secret values in the exported DSL")
|
||||
|
||||
|
||||
class RagPipelineImportResponse(ResponseModel):
|
||||
id: str
|
||||
status: ImportStatus
|
||||
pipeline_id: str | None = None
|
||||
dataset_id: str | None = None
|
||||
current_dsl_version: str
|
||||
imported_dsl_version: str
|
||||
error: str = ""
|
||||
|
||||
|
||||
class RagPipelineImportCheckDependenciesResponse(ResponseModel):
|
||||
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery)
|
||||
|
||||
|
||||
pipeline_import_model = get_or_create_model("RagPipelineImport", pipeline_import_fields)
|
||||
|
||||
leaked_dependency_model = get_or_create_model("RagPipelineLeakedDependency", leaked_dependency_fields)
|
||||
pipeline_import_check_dependencies_fields_copy = pipeline_import_check_dependencies_fields.copy()
|
||||
pipeline_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(
|
||||
fields.Nested(leaked_dependency_model)
|
||||
)
|
||||
pipeline_import_check_dependencies_model = get_or_create_model(
|
||||
"RagPipelineImportCheckDependencies", pipeline_import_check_dependencies_fields_copy
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
RagPipelineImportCheckDependenciesResponse,
|
||||
RagPipelineImportResponse,
|
||||
SimpleDataResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports")
|
||||
class RagPipelineImportApi(Resource):
|
||||
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
|
||||
@console_ns.response(200, "Import completed", console_ns.models[RagPipelineImportResponse.__name__])
|
||||
@console_ns.response(202, "Import pending confirmation", console_ns.models[RagPipelineImportResponse.__name__])
|
||||
@console_ns.response(400, "Import failed", console_ns.models[RagPipelineImportResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_model)
|
||||
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
def post(self, current_user: Account) -> JsonResponseWithStatus:
|
||||
# Check user role first
|
||||
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
@ -93,22 +107,23 @@ class RagPipelineImportApi(Resource):
|
||||
status = result.status
|
||||
match status:
|
||||
case ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return dump_response(RagPipelineImportResponse, result), 400
|
||||
case ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return dump_response(RagPipelineImportResponse, result), 202
|
||||
case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS:
|
||||
return result.model_dump(mode="json"), 200
|
||||
return dump_response(RagPipelineImportResponse, result), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports/<string:import_id>/confirm")
|
||||
class RagPipelineImportConfirmApi(Resource):
|
||||
@console_ns.response(200, "Import confirmed", console_ns.models[RagPipelineImportResponse.__name__])
|
||||
@console_ns.response(400, "Import failed", console_ns.models[RagPipelineImportResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_model)
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, import_id: str):
|
||||
def post(self, current_user: Account, import_id: str) -> JsonResponseWithStatus:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
account = current_user
|
||||
@ -120,34 +135,40 @@ class RagPipelineImportConfirmApi(Resource):
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
return dump_response(RagPipelineImportResponse, result), 400
|
||||
return dump_response(RagPipelineImportResponse, result), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports/<string:pipeline_id>/check-dependencies")
|
||||
class RagPipelineImportCheckDependenciesApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Dependencies checked",
|
||||
console_ns.models[RagPipelineImportCheckDependenciesResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_rag_pipeline
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_check_dependencies_model)
|
||||
def get(self, pipeline: Pipeline):
|
||||
def get(self, pipeline: Pipeline) -> JsonResponseWithStatus:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
result = import_service.check_dependencies(pipeline=pipeline)
|
||||
|
||||
return result.model_dump(mode="json"), 200
|
||||
return dump_response(RagPipelineImportCheckDependenciesResponse, result), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<string:pipeline_id>/exports")
|
||||
class RagPipelineExportApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(IncludeSecretQuery))
|
||||
@console_ns.response(200, "Pipeline exported", console_ns.models[SimpleDataResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_rag_pipeline
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, pipeline: Pipeline):
|
||||
def get(self, pipeline: Pipeline) -> JsonResponseWithStatus:
|
||||
# Add include_secret params
|
||||
query = IncludeSecretQuery.model_validate(request.args.to_dict())
|
||||
|
||||
@ -157,4 +178,4 @@ class RagPipelineExportApi(Resource):
|
||||
pipeline=pipeline, include_secret=query.include_secret == "true"
|
||||
)
|
||||
|
||||
return {"data": result}, 200
|
||||
return dump_response(SimpleDataResponse, {"data": result}), 200
|
||||
|
||||
@ -18,6 +18,7 @@ from controllers.console.app.error import (
|
||||
)
|
||||
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import with_current_user_id
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
@ -135,20 +136,18 @@ class CompletionApi(InstalledAppResource):
|
||||
)
|
||||
class CompletionStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app: InstalledApp, task_id: str):
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, installed_app: InstalledApp, task_id: str):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
@ -215,7 +214,8 @@ class ChatApi(InstalledAppResource):
|
||||
)
|
||||
class ChatStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app: InstalledApp, task_id: str):
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, installed_app: InstalledApp, task_id: str):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
@ -223,13 +223,10 @@ class ChatStopApi(InstalledAppResource):
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
app_mode=app_mode,
|
||||
)
|
||||
|
||||
|
||||
@ -12,14 +12,19 @@ from controllers.common.fields import SimpleMessageResponse, SimpleResultMessage
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, InstalledApp, RecommendedApp
|
||||
from libs.login import login_required
|
||||
from models import Account, App, InstalledApp, RecommendedApp
|
||||
from models.model import IconType
|
||||
from services.account_service import TenantService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@ -131,9 +136,10 @@ class InstalledAppsListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__])
|
||||
def get(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
query = InstalledAppsListQuery.model_validate(request.args.to_dict())
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if query.app_id:
|
||||
installed_apps = db.session.scalars(
|
||||
@ -212,7 +218,8 @@ class InstalledAppsListApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleMessageResponse.__name__])
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
recommended_app = db.session.scalar(
|
||||
@ -221,8 +228,6 @@ class InstalledAppsListApi(Resource):
|
||||
if recommended_app is None:
|
||||
raise NotFound("Recommended app not found")
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
app = db.session.get(App, payload.app_id)
|
||||
|
||||
if app is None:
|
||||
@ -262,8 +267,8 @@ class InstalledAppApi(InstalledAppResource):
|
||||
"""
|
||||
|
||||
@console_ns.response(204, "App uninstalled successfully")
|
||||
def delete(self, installed_app: InstalledApp):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, installed_app: InstalledApp):
|
||||
if installed_app.app_owner_tenant_id == current_tenant_id:
|
||||
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
||||
|
||||
|
||||
@ -14,7 +14,13 @@ from sqlalchemy.orm import Session, sessionmaker
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, model_validate, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
model_validate,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
@ -23,8 +29,8 @@ from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from libs.login import login_required
|
||||
from models import Account, App
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRun
|
||||
@ -48,9 +54,8 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
"""Console API for getting human input form definition."""
|
||||
|
||||
@staticmethod
|
||||
def _ensure_console_access(form: Form):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
def _ensure_console_access(form: Form, current_tenant_id: str) -> None:
|
||||
"""Ensure a console form token resolves only inside the current tenant."""
|
||||
if form.tenant_id != current_tenant_id:
|
||||
raise NotFoundError("App not found")
|
||||
|
||||
@ -62,7 +67,8 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, form_token: str):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, form_token: str):
|
||||
"""
|
||||
Get human input form definition by form token.
|
||||
|
||||
@ -73,15 +79,23 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
if form is None:
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
|
||||
self._ensure_console_access(form)
|
||||
self._ensure_console_access(form, current_tenant_id)
|
||||
|
||||
return _jsonify_form_definition(form)
|
||||
|
||||
@account_initialization_required
|
||||
@login_required
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
@model_validate(HumanInputFormSubmitPayload)
|
||||
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
def post(self, payload: HumanInputFormSubmitPayload, form_token: str):
|
||||
def post(
|
||||
self,
|
||||
payload: HumanInputFormSubmitPayload,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
form_token: str,
|
||||
):
|
||||
"""
|
||||
Submit human input form by form token.
|
||||
|
||||
@ -95,14 +109,12 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
"action": "Approve"
|
||||
}
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
|
||||
self._ensure_console_access(form)
|
||||
self._ensure_console_access(form, current_tenant_id)
|
||||
self._ensure_console_recipient_type(form)
|
||||
recipient_type = form.recipient_type
|
||||
# The type checker is not smart enought to validate the following invariant.
|
||||
@ -126,7 +138,9 @@ class ConsoleWorkflowEventsApi(Resource):
|
||||
|
||||
@account_initialization_required
|
||||
@login_required
|
||||
def get(self, workflow_run_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account, workflow_run_id: str):
|
||||
"""
|
||||
Get workflow execution events stream after resume.
|
||||
|
||||
@ -134,8 +148,6 @@ class ConsoleWorkflowEventsApi(Resource):
|
||||
|
||||
Returns Server-Sent Events stream.
|
||||
"""
|
||||
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
session_maker = sessionmaker(db.engine)
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
|
||||
|
||||
@ -13,7 +13,7 @@ from controllers.common.errors import (
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import with_current_user
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file import remote_fetcher
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
|
||||
from graphon.file import helpers as file_helpers
|
||||
@ -36,9 +36,9 @@ class GetRemoteFileInfo(Resource):
|
||||
@login_required
|
||||
def get(self, url: str):
|
||||
decoded_url = helpers.decode_remote_url(url, request.query_string)
|
||||
resp = ssrf_proxy.head(decoded_url)
|
||||
resp = remote_fetcher.make_request("HEAD", decoded_url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
resp = ssrf_proxy.get(decoded_url, timeout=3)
|
||||
resp = remote_fetcher.make_request("GET", decoded_url, timeout=3)
|
||||
resp.raise_for_status()
|
||||
return RemoteFileInfo(
|
||||
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
|
||||
@ -58,9 +58,9 @@ class RemoteFileUpload(Resource):
|
||||
|
||||
# Try to fetch remote file metadata/content first
|
||||
try:
|
||||
resp = ssrf_proxy.head(url=url)
|
||||
resp = remote_fetcher.make_request("HEAD", url=url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
|
||||
resp = remote_fetcher.make_request("GET", url=url, timeout=3, follow_redirects=True)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
# Normalize into a user-friendly error message expected by tests
|
||||
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
|
||||
@ -74,7 +74,7 @@ class RemoteFileUpload(Resource):
|
||||
raise FileTooLargeError()
|
||||
|
||||
# Load content if needed
|
||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
content = resp.content if resp.request.method == "GET" else remote_fetcher.make_request("GET", url).content
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
|
||||
@ -1,160 +0,0 @@
|
||||
import uuid
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class SnippetListQuery(BaseModel):
|
||||
"""Query parameters for listing snippets."""
|
||||
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
keyword: str | None = None
|
||||
is_published: bool | None = Field(default=None, description="Filter by published status")
|
||||
creators: list[str] | None = Field(default=None, description="Filter by creator account IDs")
|
||||
tag_ids: list[str] | None = Field(default=None, description="Filter by tag IDs")
|
||||
|
||||
@field_validator("creators", mode="before")
|
||||
@classmethod
|
||||
def parse_creators(cls, value: object) -> list[str] | None:
|
||||
"""Normalize creators filter from query string or list input."""
|
||||
return cls._normalize_string_list(value)
|
||||
|
||||
@field_validator("tag_ids", mode="before")
|
||||
@classmethod
|
||||
def parse_tag_ids(cls, value: object) -> list[str] | None:
|
||||
"""Normalize and validate tag IDs from query string or list input."""
|
||||
items = cls._normalize_string_list(value)
|
||||
if not items:
|
||||
return None
|
||||
try:
|
||||
return [str(uuid.UUID(item)) for item in items]
|
||||
except ValueError as exc:
|
||||
raise ValueError("Invalid UUID format in tag_ids.") from exc
|
||||
|
||||
@staticmethod
|
||||
def _normalize_string_list(value: object) -> list[str] | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return [item.strip() for item in value.split(",") if item.strip()] or None
|
||||
if isinstance(value, list):
|
||||
return [str(item).strip() for item in value if str(item).strip()] or None
|
||||
return None
|
||||
|
||||
|
||||
class IconInfo(BaseModel):
|
||||
"""Icon information model."""
|
||||
|
||||
icon: str | None = None
|
||||
icon_type: Literal["emoji", "image"] | None = None
|
||||
icon_background: str | None = None
|
||||
icon_url: str | None = None
|
||||
|
||||
|
||||
class InputFieldDefinition(BaseModel):
|
||||
"""Input field definition for snippet parameters."""
|
||||
|
||||
default: str | None = None
|
||||
hint: bool | None = None
|
||||
label: str | None = None
|
||||
max_length: int | None = None
|
||||
options: list[str] | None = None
|
||||
placeholder: str | None = None
|
||||
required: bool | None = None
|
||||
type: str | None = None # e.g., "text-input"
|
||||
|
||||
|
||||
class CreateSnippetPayload(BaseModel):
|
||||
"""Payload for creating a new snippet."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
description: str | None = Field(default=None, max_length=2000)
|
||||
type: Literal["node", "group"] = "node"
|
||||
icon_info: IconInfo | None = None
|
||||
graph: dict[str, Any] | None = None
|
||||
input_fields: list[InputFieldDefinition] | None = Field(default_factory=list)
|
||||
|
||||
|
||||
class UpdateSnippetPayload(BaseModel):
|
||||
"""Payload for updating a snippet."""
|
||||
|
||||
name: str | None = Field(default=None, min_length=1, max_length=255)
|
||||
description: str | None = Field(default=None, max_length=2000)
|
||||
icon_info: IconInfo | None = None
|
||||
|
||||
|
||||
class SnippetDraftSyncPayload(BaseModel):
|
||||
"""Payload for syncing snippet draft workflow."""
|
||||
|
||||
graph: dict[str, Any]
|
||||
hash: str | None = None
|
||||
conversation_variables: list[dict[str, Any]] | None = Field(
|
||||
default=None,
|
||||
description="Ignored. Snippet workflows do not persist conversation variables.",
|
||||
)
|
||||
input_fields: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class SnippetWorkflowListQuery(BaseModel):
|
||||
"""Query parameters for listing snippet published workflows."""
|
||||
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
|
||||
|
||||
class WorkflowRunQuery(BaseModel):
|
||||
"""Query parameters for workflow runs."""
|
||||
|
||||
last_id: str | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SnippetDraftRunPayload(BaseModel):
|
||||
"""Payload for running snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class SnippetDraftNodeRunPayload(BaseModel):
|
||||
"""Payload for running a single node in snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any]
|
||||
query: str = ""
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class SnippetIterationNodeRunPayload(BaseModel):
|
||||
"""Payload for running an iteration node in snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class SnippetLoopNodeRunPayload(BaseModel):
|
||||
"""Payload for running a loop node in snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class PublishWorkflowPayload(BaseModel):
|
||||
"""Payload for publishing snippet workflow."""
|
||||
|
||||
knowledge_base_setting: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class SnippetImportPayload(BaseModel):
|
||||
"""Payload for importing snippet from DSL."""
|
||||
|
||||
mode: str = Field(..., description="Import mode: yaml-content or yaml-url")
|
||||
yaml_content: str | None = Field(default=None, description="YAML content (required for yaml-content mode)")
|
||||
yaml_url: str | None = Field(default=None, description="YAML URL (required for yaml-url mode)")
|
||||
name: str | None = Field(default=None, description="Override snippet name")
|
||||
description: str | None = Field(default=None, description="Override snippet description")
|
||||
snippet_id: str | None = Field(default=None, description="Snippet ID to update (optional)")
|
||||
|
||||
|
||||
class IncludeSecretQuery(BaseModel):
|
||||
"""Query parameter for including secret variables in export."""
|
||||
|
||||
include_secret: str = Field(default="false", description="Whether to include secret variables")
|
||||
@ -1,638 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.workflow import (
|
||||
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE,
|
||||
WorkflowPaginationResponse,
|
||||
WorkflowResponse,
|
||||
)
|
||||
from controllers.console.snippets.payloads import (
|
||||
PublishWorkflowPayload,
|
||||
SnippetDraftNodeRunPayload,
|
||||
SnippetDraftRunPayload,
|
||||
SnippetDraftSyncPayload,
|
||||
SnippetIterationNodeRunPayload,
|
||||
SnippetLoopNodeRunPayload,
|
||||
SnippetWorkflowListQuery,
|
||||
WorkflowRunQuery,
|
||||
)
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.workflow_run_fields import (
|
||||
WorkflowRunDetailResponse,
|
||||
WorkflowRunNodeExecutionListResponse,
|
||||
WorkflowRunNodeExecutionResponse,
|
||||
WorkflowRunPaginationResponse,
|
||||
)
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.snippet import CustomizedSnippet
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
|
||||
from services.snippet_generate_service import SnippetGenerateService
|
||||
from services.snippet_service import SnippetService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
# Register Pydantic models with Swagger
|
||||
class SnippetWorkflowResponse(WorkflowResponse):
|
||||
input_fields: list[dict] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
SnippetDraftSyncPayload,
|
||||
SnippetDraftNodeRunPayload,
|
||||
SnippetDraftRunPayload,
|
||||
SnippetIterationNodeRunPayload,
|
||||
SnippetLoopNodeRunPayload,
|
||||
SnippetWorkflowListQuery,
|
||||
WorkflowRunQuery,
|
||||
PublishWorkflowPayload,
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
SnippetWorkflowResponse,
|
||||
WorkflowPaginationResponse,
|
||||
WorkflowRunPaginationResponse,
|
||||
WorkflowRunDetailResponse,
|
||||
WorkflowRunNodeExecutionListResponse,
|
||||
WorkflowRunNodeExecutionResponse,
|
||||
)
|
||||
|
||||
|
||||
class SnippetNotFoundError(Exception):
|
||||
"""Snippet not found error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def get_snippet(view_func: Callable[P, R]):
|
||||
"""Decorator to fetch and validate snippet access."""
|
||||
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
if not kwargs.get("snippet_id"):
|
||||
raise ValueError("missing snippet_id in path parameters")
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet_id = str(kwargs.get("snippet_id"))
|
||||
del kwargs["snippet_id"]
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=snippet_id,
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
kwargs["snippet"] = snippet
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft")
|
||||
class SnippetDraftWorkflowApi(Resource):
|
||||
@console_ns.doc("get_snippet_draft_workflow")
|
||||
@console_ns.response(200, "Draft workflow retrieved successfully", console_ns.models[SnippetWorkflowResponse.__name__])
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get draft workflow for snippet."""
|
||||
snippet_service = SnippetService()
|
||||
workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
|
||||
if not workflow:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
db.session.expunge(workflow)
|
||||
workflow.conversation_variables = []
|
||||
workflow.input_fields = snippet.input_fields_list
|
||||
return SnippetWorkflowResponse.model_validate(workflow, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("sync_snippet_draft_workflow")
|
||||
@console_ns.expect(console_ns.models.get(SnippetDraftSyncPayload.__name__))
|
||||
@console_ns.response(200, "Draft workflow synced successfully")
|
||||
@console_ns.response(400, "Hash mismatch")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet):
|
||||
"""Sync draft workflow for snippet."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
snippet_service = SnippetService()
|
||||
workflow = snippet_service.sync_draft_workflow(
|
||||
snippet=snippet,
|
||||
graph=payload.graph,
|
||||
unique_hash=payload.hash,
|
||||
account=current_user,
|
||||
input_fields=payload.input_fields,
|
||||
)
|
||||
except WorkflowHashNotEqualError:
|
||||
raise DraftWorkflowNotSync()
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"hash": workflow.unique_hash,
|
||||
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/config")
|
||||
class SnippetDraftConfigApi(Resource):
|
||||
@console_ns.doc("get_snippet_draft_config")
|
||||
@console_ns.response(200, "Draft config retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get snippet draft workflow configuration limits."""
|
||||
return {
|
||||
"parallel_depth_limit": 3,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/publish")
|
||||
class SnippetPublishedWorkflowApi(Resource):
|
||||
@console_ns.doc("get_snippet_published_workflow")
|
||||
@console_ns.response(200, "Published workflow retrieved successfully", console_ns.models[SnippetWorkflowResponse.__name__])
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get published workflow for snippet."""
|
||||
if not snippet.is_published:
|
||||
return None
|
||||
|
||||
snippet_service = SnippetService()
|
||||
workflow = snippet_service.get_published_workflow(snippet=snippet)
|
||||
|
||||
if not workflow:
|
||||
return None
|
||||
|
||||
workflow.input_fields = snippet.input_fields_list
|
||||
return SnippetWorkflowResponse.model_validate(workflow, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("publish_snippet_workflow")
|
||||
@console_ns.expect(console_ns.models.get(PublishWorkflowPayload.__name__))
|
||||
@console_ns.response(200, "Workflow published successfully")
|
||||
@console_ns.response(400, "No draft workflow found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet):
|
||||
"""Publish snippet workflow."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
snippet_service = SnippetService()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
snippet = session.merge(snippet)
|
||||
try:
|
||||
workflow = snippet_service.publish_workflow(
|
||||
session=session,
|
||||
snippet=snippet,
|
||||
account=current_user,
|
||||
)
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
session.commit()
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"created_at": workflow_created_at,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/default-workflow-block-configs")
|
||||
class SnippetDefaultBlockConfigsApi(Resource):
|
||||
@console_ns.doc("get_snippet_default_block_configs")
|
||||
@console_ns.response(200, "Default block configs retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get default block configurations for snippet workflow."""
|
||||
snippet_service = SnippetService()
|
||||
return snippet_service.get_default_block_configs()
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows")
|
||||
class SnippetPublishedAllWorkflowApi(Resource):
|
||||
@console_ns.expect(console_ns.models[SnippetWorkflowListQuery.__name__])
|
||||
@console_ns.doc("get_all_snippet_published_workflows")
|
||||
@console_ns.doc(description="Get all published workflows for a snippet")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID"})
|
||||
@console_ns.response(200, "Published workflows retrieved successfully", console_ns.models[WorkflowPaginationResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get all published workflow versions for snippet."""
|
||||
args = SnippetWorkflowListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
snippet_service = SnippetService()
|
||||
with Session(db.engine) as session:
|
||||
workflows, has_more = snippet_service.get_all_published_workflows(
|
||||
session=session,
|
||||
snippet=snippet,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
)
|
||||
|
||||
return WorkflowPaginationResponse.model_validate(
|
||||
{
|
||||
"items": workflows,
|
||||
"page": args.page,
|
||||
"limit": args.limit,
|
||||
"has_more": has_more,
|
||||
},
|
||||
from_attributes=True,
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/<string:workflow_id>/restore")
|
||||
class SnippetDraftWorkflowRestoreApi(Resource):
|
||||
@console_ns.doc("restore_snippet_workflow_to_draft")
|
||||
@console_ns.doc(description="Restore a published snippet workflow version into the draft workflow")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "workflow_id": "Published workflow ID"})
|
||||
@console_ns.response(200, "Workflow restored successfully")
|
||||
@console_ns.response(400, "Source workflow must be published")
|
||||
@console_ns.response(404, "Workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, workflow_id: str):
|
||||
"""Restore a published snippet workflow version into the draft workflow."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
snippet_service = SnippetService()
|
||||
|
||||
try:
|
||||
workflow = snippet_service.restore_published_workflow_to_draft(
|
||||
snippet=snippet,
|
||||
workflow_id=workflow_id,
|
||||
account=current_user,
|
||||
)
|
||||
except IsDraftWorkflowError as exc:
|
||||
raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc
|
||||
except WorkflowNotFoundError as exc:
|
||||
raise NotFound(str(exc)) from exc
|
||||
except ValueError as exc:
|
||||
raise BadRequest(str(exc)) from exc
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"hash": workflow.unique_hash,
|
||||
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs")
|
||||
class SnippetWorkflowRunsApi(Resource):
|
||||
@console_ns.doc("list_snippet_workflow_runs")
|
||||
@console_ns.response(200, "Workflow runs retrieved successfully", console_ns.models[WorkflowRunPaginationResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""List workflow runs for snippet."""
|
||||
query = WorkflowRunQuery.model_validate(
|
||||
{
|
||||
"last_id": request.args.get("last_id"),
|
||||
"limit": request.args.get("limit", type=int, default=20),
|
||||
}
|
||||
)
|
||||
args = {
|
||||
"last_id": query.last_id,
|
||||
"limit": query.limit,
|
||||
}
|
||||
|
||||
snippet_service = SnippetService()
|
||||
result = snippet_service.get_snippet_workflow_runs(snippet=snippet, args=args)
|
||||
|
||||
return WorkflowRunPaginationResponse.model_validate(result, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>")
|
||||
class SnippetWorkflowRunDetailApi(Resource):
|
||||
@console_ns.doc("get_snippet_workflow_run_detail")
|
||||
@console_ns.response(200, "Workflow run detail retrieved successfully", console_ns.models[WorkflowRunDetailResponse.__name__])
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
def get(self, snippet: CustomizedSnippet, run_id):
|
||||
"""Get workflow run detail for snippet."""
|
||||
run_id = str(run_id)
|
||||
|
||||
snippet_service = SnippetService()
|
||||
workflow_run = snippet_service.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
|
||||
|
||||
if not workflow_run:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
return WorkflowRunDetailResponse.model_validate(workflow_run, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>/node-executions")
|
||||
class SnippetWorkflowRunNodeExecutionsApi(Resource):
|
||||
@console_ns.doc("list_snippet_workflow_run_node_executions")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Node executions retrieved successfully",
|
||||
console_ns.models[WorkflowRunNodeExecutionListResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
def get(self, snippet: CustomizedSnippet, run_id):
|
||||
"""List node executions for a workflow run."""
|
||||
run_id = str(run_id)
|
||||
|
||||
snippet_service = SnippetService()
|
||||
node_executions = snippet_service.get_snippet_workflow_run_node_executions(
|
||||
snippet=snippet,
|
||||
run_id=run_id,
|
||||
)
|
||||
|
||||
return WorkflowRunNodeExecutionListResponse.model_validate(
|
||||
{"data": node_executions}, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/run")
|
||||
class SnippetDraftNodeRunApi(Resource):
|
||||
@console_ns.doc("run_snippet_draft_node")
|
||||
@console_ns.doc(description="Run a single node in snippet draft workflow (single-step debugging)")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models.get(SnippetDraftNodeRunPayload.__name__))
|
||||
@console_ns.response(
|
||||
200, "Node run completed successfully", console_ns.models[WorkflowRunNodeExecutionResponse.__name__]
|
||||
)
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Run a single node in snippet draft workflow.
|
||||
|
||||
Executes a specific node with provided inputs for single-step debugging.
|
||||
Returns the node execution result including status, outputs, and timing.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = SnippetDraftNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
user_inputs = payload.inputs
|
||||
|
||||
# Get draft workflow for file parsing
|
||||
snippet_service = SnippetService()
|
||||
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if not draft_workflow:
|
||||
raise NotFound("Draft workflow not found")
|
||||
|
||||
files = SnippetGenerateService.parse_files(draft_workflow, payload.files)
|
||||
|
||||
workflow_node_execution = SnippetGenerateService.run_draft_node(
|
||||
snippet=snippet,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
account=current_user,
|
||||
query=payload.query,
|
||||
files=files,
|
||||
)
|
||||
|
||||
return WorkflowRunNodeExecutionResponse.model_validate(
|
||||
workflow_node_execution, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/last-run")
|
||||
class SnippetDraftNodeLastRunApi(Resource):
|
||||
@console_ns.doc("get_snippet_draft_node_last_run")
|
||||
@console_ns.doc(description="Get last run result for a node in snippet draft workflow")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
@console_ns.response(
|
||||
200, "Node last run retrieved successfully", console_ns.models[WorkflowRunNodeExecutionResponse.__name__]
|
||||
)
|
||||
@console_ns.response(404, "Snippet, draft workflow, or node last run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
def get(self, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Get the last run result for a specific node in snippet draft workflow.
|
||||
|
||||
Returns the most recent execution record for the given node,
|
||||
including status, inputs, outputs, and timing information.
|
||||
"""
|
||||
snippet_service = SnippetService()
|
||||
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if not draft_workflow:
|
||||
raise NotFound("Draft workflow not found")
|
||||
|
||||
node_exec = snippet_service.get_snippet_node_last_run(
|
||||
snippet=snippet,
|
||||
workflow=draft_workflow,
|
||||
node_id=node_id,
|
||||
)
|
||||
if node_exec is None:
|
||||
raise NotFound("Node last run not found")
|
||||
|
||||
return WorkflowRunNodeExecutionResponse.model_validate(node_exec, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
|
||||
class SnippetDraftRunIterationNodeApi(Resource):
|
||||
@console_ns.doc("run_snippet_draft_iteration_node")
|
||||
@console_ns.doc(description="Run draft workflow iteration node for snippet")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models.get(SnippetIterationNodeRunPayload.__name__))
|
||||
@console_ns.response(200, "Iteration node run started successfully (SSE stream)")
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Run a draft workflow iteration node for snippet.
|
||||
|
||||
Iteration nodes execute their internal sub-graph multiple times over an input list.
|
||||
Returns an SSE event stream with iteration progress and results.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = SnippetIterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = SnippetGenerateService.generate_single_iteration(
|
||||
snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/loop/nodes/<string:node_id>/run")
|
||||
class SnippetDraftRunLoopNodeApi(Resource):
|
||||
@console_ns.doc("run_snippet_draft_loop_node")
|
||||
@console_ns.doc(description="Run draft workflow loop node for snippet")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models.get(SnippetLoopNodeRunPayload.__name__))
|
||||
@console_ns.response(200, "Loop node run started successfully (SSE stream)")
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Run a draft workflow loop node for snippet.
|
||||
|
||||
Loop nodes execute their internal sub-graph repeatedly until a condition is met.
|
||||
Returns an SSE event stream with loop progress and results.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = SnippetLoopNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
response = SnippetGenerateService.generate_single_loop(
|
||||
snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/run")
|
||||
class SnippetDraftWorkflowRunApi(Resource):
|
||||
@console_ns.doc("run_snippet_draft_workflow")
|
||||
@console_ns.expect(console_ns.models.get(SnippetDraftRunPayload.__name__))
|
||||
@console_ns.response(200, "Draft workflow run started successfully (SSE stream)")
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet):
|
||||
"""
|
||||
Run draft workflow for snippet.
|
||||
|
||||
Executes the snippet's draft workflow with the provided inputs
|
||||
and returns an SSE event stream with execution progress and results.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = SnippetDraftRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = SnippetGenerateService.generate(
|
||||
snippet=snippet,
|
||||
user=current_user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/tasks/<string:task_id>/stop")
|
||||
class SnippetWorkflowTaskStopApi(Resource):
|
||||
@console_ns.doc("stop_snippet_workflow_task")
|
||||
@console_ns.response(200, "Task stopped successfully")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, task_id: str):
|
||||
"""
|
||||
Stop a running snippet workflow task.
|
||||
|
||||
Uses both the legacy stop flag mechanism and the graph engine
|
||||
command channel for backward compatibility.
|
||||
"""
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
@ -1,319 +0,0 @@
|
||||
"""
|
||||
Snippet draft workflow variable APIs.
|
||||
|
||||
Mirrors console app routes under /apps/.../workflows/draft/variables for snippet scope,
|
||||
using CustomizedSnippet.id as WorkflowDraftVariable.app_id (same invariant as snippet execution).
|
||||
|
||||
Snippet workflows do not expose system variables (`node_id == sys`) or conversation variables
|
||||
(`node_id == conversation`): paginated list queries exclude those rows; single-variable GET/PATCH/DELETE/reset
|
||||
reject them; `GET .../system-variables` and `GET .../conversation-variables` return empty lists for API parity.
|
||||
Other routes mirror `workflow_draft_variable` app APIs under `/snippets/...`.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, marshal, marshal_with
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import DraftWorkflowNotExist
|
||||
from controllers.console.app.workflow_draft_variable import (
|
||||
WorkflowDraftVariableListQuery,
|
||||
WorkflowDraftVariableUpdatePayload,
|
||||
_ensure_variable_access,
|
||||
_file_access_controller,
|
||||
validate_node_id,
|
||||
workflow_draft_variable_list_model,
|
||||
workflow_draft_variable_list_without_value_model,
|
||||
workflow_draft_variable_model,
|
||||
)
|
||||
from controllers.console.snippets.snippet_workflow import get_snippet
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from graphon.variables.types import SegmentType
|
||||
from libs.login import current_user, login_required
|
||||
from models.snippet import CustomizedSnippet
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.snippet_service import SnippetService
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS: frozenset[str] = frozenset(
|
||||
{SYSTEM_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID}
|
||||
)
|
||||
|
||||
|
||||
def _ensure_snippet_draft_variable_row_allowed(
|
||||
*,
|
||||
variable: WorkflowDraftVariable,
|
||||
variable_id: str,
|
||||
) -> None:
|
||||
"""Snippet scope only supports canvas-node draft variables; treat sys/conversation rows as not found."""
|
||||
if variable.node_id in _SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
|
||||
def _snippet_draft_var_prerequisite(f: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Setup, auth, snippet resolution, and tenant edit permission (same stack as snippet workflow APIs)."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
@wraps(f)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables")
|
||||
class SnippetWorkflowVariableCollectionApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
|
||||
@console_ns.doc("get_snippet_workflow_variables")
|
||||
@console_ns.doc(description="List draft workflow variables without values (paginated, snippet scope)")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow variables retrieved successfully",
|
||||
workflow_draft_variable_list_without_value_model,
|
||||
)
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_without_value_model)
|
||||
def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
snippet_service = SnippetService()
|
||||
if snippet_service.get_draft_workflow(snippet=snippet) is None:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=session)
|
||||
workflow_vars = draft_var_srv.list_variables_without_values(
|
||||
app_id=snippet.id,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
user_id=current_user.id,
|
||||
exclude_node_ids=_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS,
|
||||
)
|
||||
|
||||
return workflow_vars
|
||||
|
||||
@console_ns.doc("delete_snippet_workflow_variables")
|
||||
@console_ns.doc(description="Delete all draft workflow variables for the current user (snippet scope)")
|
||||
@console_ns.response(204, "Workflow variables deleted successfully")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def delete(self, snippet: CustomizedSnippet) -> Response:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
draft_var_srv.delete_user_workflow_variables(snippet.id, user_id=current_user.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
class SnippetNodeVariableCollectionApi(Resource):
|
||||
@console_ns.doc("get_snippet_node_variables")
|
||||
@console_ns.doc(description="Get variables for a specific node (snippet draft workflow)")
|
||||
@console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, snippet: CustomizedSnippet, node_id: str) -> WorkflowDraftVariableList:
|
||||
validate_node_id(node_id)
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=session)
|
||||
node_vars = draft_var_srv.list_node_variables(snippet.id, node_id, user_id=current_user.id)
|
||||
|
||||
return node_vars
|
||||
|
||||
@console_ns.doc("delete_snippet_node_variables")
|
||||
@console_ns.doc(description="Delete all variables for a specific node (snippet draft workflow)")
|
||||
@console_ns.response(204, "Node variables deleted successfully")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def delete(self, snippet: CustomizedSnippet, node_id: str) -> Response:
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(snippet.id, node_id, user_id=current_user.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables/<uuid:variable_id>")
|
||||
class SnippetVariableApi(Resource):
|
||||
@console_ns.doc("get_snippet_workflow_variable")
|
||||
@console_ns.doc(description="Get a specific draft workflow variable (snippet scope)")
|
||||
@console_ns.response(200, "Variable retrieved successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def get(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=snippet.id,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
return variable
|
||||
|
||||
@console_ns.doc("update_snippet_workflow_variable")
|
||||
@console_ns.doc(description="Update a draft workflow variable (snippet scope)")
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def patch(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=snippet.id,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
|
||||
new_name = args_model.name
|
||||
raw_value = args_model.value
|
||||
if new_name is None and raw_value is None:
|
||||
return variable
|
||||
|
||||
new_value = None
|
||||
if raw_value is not None:
|
||||
if variable.value_type == SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(
|
||||
mapping=raw_value,
|
||||
tenant_id=snippet.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
elif variable.value_type == SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(
|
||||
mappings=raw_value,
|
||||
tenant_id=snippet.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
return variable
|
||||
|
||||
@console_ns.doc("delete_snippet_workflow_variable")
|
||||
@console_ns.doc(description="Delete a draft workflow variable (snippet scope)")
|
||||
@console_ns.response(204, "Variable deleted successfully")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def delete(self, snippet: CustomizedSnippet, variable_id: str) -> Response:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=snippet.id,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
class SnippetVariableResetApi(Resource):
|
||||
@console_ns.doc("reset_snippet_workflow_variable")
|
||||
@console_ns.doc(description="Reset a draft workflow variable to its default value (snippet scope)")
|
||||
@console_ns.response(200, "Variable reset successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(204, "Variable reset (no content)")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def put(self, snippet: CustomizedSnippet, variable_id: str) -> Response | Any:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
snippet_service = SnippetService()
|
||||
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if draft_workflow is None:
|
||||
raise NotFoundError(
|
||||
f"Draft workflow not found, snippet_id={snippet.id}",
|
||||
)
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=snippet.id,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
db.session.commit()
|
||||
if resetted is None:
|
||||
return Response("", 204)
|
||||
return marshal(resetted, workflow_draft_variable_model)
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/conversation-variables")
|
||||
class SnippetConversationVariableCollectionApi(Resource):
|
||||
@console_ns.doc("get_snippet_conversation_variables")
|
||||
@console_ns.doc(
|
||||
description="Conversation variables are not used in snippet workflows; returns an empty list for API parity"
|
||||
)
|
||||
@console_ns.response(200, "Conversation variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
return WorkflowDraftVariableList(variables=[])
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/system-variables")
|
||||
class SnippetSystemVariableCollectionApi(Resource):
|
||||
@console_ns.doc("get_snippet_system_variables")
|
||||
@console_ns.doc(
|
||||
description="System variables are not used in snippet workflows; returns an empty list for API parity"
|
||||
)
|
||||
@console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
return WorkflowDraftVariableList(variables=[])
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/environment-variables")
|
||||
class SnippetEnvironmentVariableCollectionApi(Resource):
|
||||
@console_ns.doc("get_snippet_environment_variables")
|
||||
@console_ns.doc(description="Get environment variables from snippet draft workflow graph")
|
||||
@console_ns.response(200, "Environment variables retrieved successfully")
|
||||
@console_ns.response(404, "Draft workflow not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def get(self, snippet: CustomizedSnippet) -> dict[str, list[dict[str, Any]]]:
|
||||
snippet_service = SnippetService()
|
||||
workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if workflow is None:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
env_vars_list: list[dict[str, Any]] = []
|
||||
for v in workflow.environment_variables:
|
||||
env_vars_list.append(
|
||||
{
|
||||
"id": v.id,
|
||||
"type": "env",
|
||||
"name": v.name,
|
||||
"description": v.description,
|
||||
"selector": v.selector,
|
||||
"value_type": v.value_type.exposed_type().value,
|
||||
"value": v.value,
|
||||
"edited": False,
|
||||
"visible": True,
|
||||
"editable": True,
|
||||
}
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
@ -51,7 +51,7 @@ class TagBindingRemovePayload(BaseModel):
|
||||
|
||||
|
||||
class TagListQueryParam(BaseModel):
|
||||
type: Literal["knowledge", "app", "snippet", ""] = Field("", description="Tag type filter")
|
||||
type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter")
|
||||
keyword: str | None = Field(None, description="Search keyword")
|
||||
|
||||
|
||||
@ -96,10 +96,7 @@ class TagListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"type": 'Tag type filter. Can be "knowledge", "app", or "snippet".',
|
||||
"keyword": "Search keyword for tag name.",
|
||||
}
|
||||
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
|
||||
)
|
||||
@console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])})
|
||||
@with_current_tenant_id
|
||||
|
||||
@ -18,7 +18,7 @@ from controllers.common.fields import (
|
||||
SimpleResultResponse,
|
||||
VerificationTokenResponse,
|
||||
)
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
@ -42,15 +42,17 @@ from controllers.console.wraps import (
|
||||
enterprise_license_required,
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from fields.member_fields import Account as AccountResponse
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, extract_remote_ip, timezone, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AccountIntegrate, InvitationCode
|
||||
from libs.helper import EmailStr, dump_response, extract_remote_ip, timezone, to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import Account, AccountIntegrate, InvitationCode
|
||||
from models.account import AccountStatus, InvitationCodeStatus
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
@ -173,7 +175,6 @@ class CheckEmailUniquePayload(BaseModel):
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
AccountResponse,
|
||||
AccountInitPayload,
|
||||
AccountNamePayload,
|
||||
AccountAvatarPayload,
|
||||
@ -245,6 +246,7 @@ register_schema_models(
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
AccountResponse,
|
||||
AvatarUrlResponse,
|
||||
SimpleResultDataResponse,
|
||||
SimpleResultResponse,
|
||||
@ -258,9 +260,8 @@ class AccountInitApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
def post(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def post(self, account: Account):
|
||||
if account.status == "active":
|
||||
raise AccountAlreadyInitedError()
|
||||
|
||||
@ -306,8 +307,8 @@ class AccountProfileApi(Resource):
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account):
|
||||
return _serialize_account(current_user)
|
||||
|
||||
|
||||
@ -318,8 +319,8 @@ class AccountNameApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountNamePayload.model_validate(payload)
|
||||
updated_account = AccountService.update_account(current_user, name=args.name)
|
||||
@ -329,20 +330,21 @@ class AccountNameApi(Resource):
|
||||
|
||||
@console_ns.route("/account/avatar")
|
||||
class AccountAvatarApi(Resource):
|
||||
@console_ns.expect(console_ns.models[AccountAvatarQuery.__name__])
|
||||
@console_ns.doc("get_account_avatar")
|
||||
@console_ns.doc(description="Get account avatar url")
|
||||
@console_ns.doc(params=query_params_from_model(AccountAvatarQuery))
|
||||
@console_ns.response(200, "Success", console_ns.models[AvatarUrlResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True))
|
||||
avatar = args.avatar
|
||||
|
||||
if avatar.startswith(("http://", "https://")):
|
||||
return {"avatar_url": avatar}
|
||||
return dump_response(AvatarUrlResponse, {"avatar_url": avatar})
|
||||
|
||||
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == avatar).limit(1))
|
||||
if upload_file is None:
|
||||
@ -355,15 +357,15 @@ class AccountAvatarApi(Resource):
|
||||
raise NotFound("Avatar file not found")
|
||||
|
||||
avatar_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
|
||||
return {"avatar_url": avatar_url}
|
||||
return dump_response(AvatarUrlResponse, {"avatar_url": avatar_url})
|
||||
|
||||
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountAvatarPayload.model_validate(payload)
|
||||
|
||||
@ -379,8 +381,8 @@ class AccountInterfaceLanguageApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountInterfaceLanguagePayload.model_validate(payload)
|
||||
|
||||
@ -396,8 +398,8 @@ class AccountInterfaceThemeApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountInterfaceThemePayload.model_validate(payload)
|
||||
|
||||
@ -413,8 +415,8 @@ class AccountTimezoneApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountTimezonePayload.model_validate(payload)
|
||||
|
||||
@ -430,8 +432,8 @@ class AccountPasswordApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountPasswordPayload.model_validate(payload)
|
||||
|
||||
@ -449,9 +451,8 @@ class AccountIntegrateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountIntegrateListResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account):
|
||||
account_integrates = db.session.scalars(
|
||||
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id)
|
||||
).all()
|
||||
@ -495,9 +496,8 @@ class AccountDeleteVerifyApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultDataResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account):
|
||||
token, code = AccountService.generate_account_deletion_verification_code(account)
|
||||
AccountService.send_account_deletion_verification_email(account, code)
|
||||
|
||||
@ -511,9 +511,8 @@ class AccountDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def post(self, account: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountDeletePayload.model_validate(payload)
|
||||
|
||||
@ -547,9 +546,8 @@ class EducationVerifyApi(Resource):
|
||||
@only_edition_cloud
|
||||
@cloud_edition_billing_enabled
|
||||
@console_ns.response(200, "Success", console_ns.models[EducationVerifyResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account):
|
||||
return EducationVerifyResponse.model_validate(
|
||||
BillingService.EducationIdentity.verify(account.id, account.email) or {}
|
||||
).model_dump(mode="json")
|
||||
@ -563,9 +561,8 @@ class EducationApi(Resource):
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
@cloud_edition_billing_enabled
|
||||
def post(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def post(self, account: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = EducationActivatePayload.model_validate(payload)
|
||||
|
||||
@ -577,9 +574,8 @@ class EducationApi(Resource):
|
||||
@only_edition_cloud
|
||||
@cloud_edition_billing_enabled
|
||||
@console_ns.response(200, "Success", console_ns.models[EducationStatusResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account):
|
||||
res = BillingService.EducationIdentity.status(account.id) or {}
|
||||
# convert expire_at to UTC timestamp from isoformat
|
||||
if res and "expire_at" in res:
|
||||
@ -613,8 +609,8 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailSendPayload.model_validate(payload)
|
||||
|
||||
@ -673,8 +669,8 @@ class ChangeEmailCheckApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailValidityPayload.model_validate(payload)
|
||||
|
||||
@ -720,7 +716,8 @@ class ChangeEmailResetApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailResetPayload.model_validate(payload)
|
||||
normalized_new_email = args.new_email.lower()
|
||||
@ -731,7 +728,6 @@ class ChangeEmailResetApi(Resource):
|
||||
if not AccountService.check_email_unique(normalized_new_email):
|
||||
raise EmailAlreadyInUseError()
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
reset_data = AccountService.get_change_email_data(args.token)
|
||||
if not reset_data:
|
||||
raise InvalidTokenError()
|
||||
|
||||
@ -1,9 +1,15 @@
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.agent_service import AgentService
|
||||
|
||||
|
||||
@ -19,14 +25,10 @@ class AgentProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
return jsonable_encoder(AgentService.list_agent_providers(current_user.id, current_tenant_id))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/agent-provider/<path:provider_name>")
|
||||
@ -42,6 +44,7 @@ class AgentProviderApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, provider_name: str):
|
||||
return jsonable_encoder(AgentService.get_agent_provider(current_user.id, current_tenant_id, provider_name))
|
||||
|
||||
@ -14,10 +14,16 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user_id,
|
||||
)
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from services.plugin.endpoint_service import EndpointService
|
||||
|
||||
|
||||
@ -96,17 +102,15 @@ register_schema_models(
|
||||
)
|
||||
|
||||
|
||||
def _create_endpoint() -> dict[str, bool]:
|
||||
"""Create a plugin endpoint for the current workspace."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
def _create_endpoint(tenant_id: str, user_id: str) -> dict[str, bool]:
|
||||
"""Create a plugin endpoint for the injected workspace and user."""
|
||||
args = EndpointCreatePayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
return {
|
||||
"success": EndpointService.create_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=user_id,
|
||||
plugin_unique_identifier=args.plugin_unique_identifier,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
@ -116,16 +120,14 @@ def _create_endpoint() -> dict[str, bool]:
|
||||
raise ValueError(e.description) from e
|
||||
|
||||
|
||||
def _update_endpoint(endpoint_id: str) -> dict[str, bool]:
|
||||
def _update_endpoint(tenant_id: str, user_id: str, endpoint_id: str) -> dict[str, bool]:
|
||||
"""Update a plugin endpoint identified by the canonical path parameter."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=user_id,
|
||||
endpoint_id=endpoint_id,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
@ -133,14 +135,12 @@ def _update_endpoint(endpoint_id: str) -> dict[str, bool]:
|
||||
}
|
||||
|
||||
|
||||
def _delete_endpoint(endpoint_id: str) -> dict[str, bool]:
|
||||
def _delete_endpoint(tenant_id: str, user_id: str, endpoint_id: str) -> dict[str, bool]:
|
||||
"""Delete a plugin endpoint identified by the canonical path parameter."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
return {
|
||||
"success": EndpointService.delete_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=user_id,
|
||||
endpoint_id=endpoint_id,
|
||||
)
|
||||
}
|
||||
@ -163,8 +163,10 @@ class EndpointCollectionApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _create_endpoint()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
return _create_endpoint(tenant_id=tenant_id, user_id=user_id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
@ -189,8 +191,10 @@ class DeprecatedEndpointCreateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _create_endpoint()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
return _create_endpoint(tenant_id=tenant_id, user_id=user_id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/list")
|
||||
@ -206,9 +210,9 @@ class EndpointListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user_id: str):
|
||||
args = EndpointListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
page = args.page
|
||||
@ -218,7 +222,7 @@ class EndpointListApi(Resource):
|
||||
{
|
||||
"endpoints": EndpointService.list_endpoints(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
@ -239,9 +243,9 @@ class EndpointListForSinglePluginApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user_id: str):
|
||||
args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
page = args.page
|
||||
@ -252,7 +256,7 @@ class EndpointListForSinglePluginApi(Resource):
|
||||
{
|
||||
"endpoints": EndpointService.list_endpoints_for_single_plugin(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
@ -278,8 +282,10 @@ class EndpointItemApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, id: str):
|
||||
return _delete_endpoint(endpoint_id=id)
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def delete(self, tenant_id: str, user_id: str, id: str):
|
||||
return _delete_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=id)
|
||||
|
||||
@console_ns.doc("update_endpoint")
|
||||
@console_ns.doc(description="Update a plugin endpoint")
|
||||
@ -295,8 +301,10 @@ class EndpointItemApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def patch(self, id: str):
|
||||
return _update_endpoint(endpoint_id=id)
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def patch(self, tenant_id: str, user_id: str, id: str):
|
||||
return _update_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/delete")
|
||||
@ -322,9 +330,11 @@ class DeprecatedEndpointDeleteApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
return _delete_endpoint(endpoint_id=args.endpoint_id)
|
||||
return _delete_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/update")
|
||||
@ -350,9 +360,11 @@ class DeprecatedEndpointUpdateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
args = LegacyEndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
return _update_endpoint(endpoint_id=args.endpoint_id)
|
||||
return _update_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/enable")
|
||||
@ -370,14 +382,14 @@ class EndpointEnableApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.enable_endpoint(
|
||||
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
|
||||
tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id
|
||||
)
|
||||
}
|
||||
|
||||
@ -397,13 +409,13 @@ class EndpointDisableApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.disable_endpoint(
|
||||
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
|
||||
tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id
|
||||
)
|
||||
}
|
||||
|
||||
@ -4,11 +4,16 @@ from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import TenantAccountRole
|
||||
from libs.login import login_required
|
||||
from models import Account, TenantAccountRole
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
|
||||
|
||||
@ -29,8 +34,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, provider: str):
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_role):
|
||||
raise Forbidden()
|
||||
|
||||
@ -72,8 +78,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str, config_id: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, provider: str, config_id: str):
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_role):
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
@ -8,12 +8,19 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.billing_service import BillingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
@ -95,10 +102,8 @@ class ModelProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
payload = request.args.to_dict(flat=True)
|
||||
args = ParserModelList.model_validate(payload)
|
||||
|
||||
@ -114,9 +119,8 @@ class ModelProviderCredentialApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
# if credential_id is not provided, return current used credential
|
||||
payload = request.args.to_dict(flat=True)
|
||||
args = ParserCredentialId.model_validate(payload)
|
||||
@ -133,8 +137,8 @@ class ModelProviderCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialCreate.model_validate(payload)
|
||||
|
||||
@ -157,9 +161,8 @@ class ModelProviderCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def put(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def put(self, current_tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialUpdate.model_validate(payload)
|
||||
|
||||
@ -184,8 +187,8 @@ class ModelProviderCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialDelete.model_validate(payload)
|
||||
|
||||
@ -205,8 +208,8 @@ class ModelProviderCredentialSwitchApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialSwitch.model_validate(payload)
|
||||
|
||||
@ -225,8 +228,8 @@ class ModelProviderValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialValidate.model_validate(payload)
|
||||
|
||||
@ -280,11 +283,8 @@ class PreferredProviderTypeUpdateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserPreferredProviderType.model_validate(payload)
|
||||
|
||||
@ -301,10 +301,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, provider: str):
|
||||
if provider != "anthropic":
|
||||
raise ValueError(f"provider name {provider} is invalid")
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
data = BillingService.get_model_provider_payment_link(
|
||||
provider_name=provider,
|
||||
|
||||
@ -13,12 +13,14 @@ from controllers.console.wraps import (
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
@ -193,7 +195,7 @@ class ModelProviderModelApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider):
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
|
||||
|
||||
@ -269,8 +271,9 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
def get(self, tenant_id: str, user: Account, provider: str):
|
||||
args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -292,9 +295,13 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
)
|
||||
|
||||
if args.config_from == "predefined-model":
|
||||
# Only the predefined-model branch needs visibility filtering by user.
|
||||
# The account is injected once by the handler and only passed into the
|
||||
# service branch that needs user-scoped credential visibility.
|
||||
available_credentials = model_provider_service.get_provider_available_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
available_credentials = model_provider_service.get_provider_model_available_credentials(
|
||||
|
||||
@ -1,407 +0,0 @@
|
||||
import logging
|
||||
import re
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, marshal
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.datastructures import MultiDict
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.snippets.payloads import (
|
||||
CreateSnippetPayload,
|
||||
IncludeSecretQuery,
|
||||
SnippetImportPayload,
|
||||
SnippetListQuery,
|
||||
UpdateSnippetPayload,
|
||||
)
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.snippet_fields import snippet_fields, snippet_list_fields, snippet_pagination_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.snippet import SnippetType
|
||||
from services.app_dsl_service import ImportStatus
|
||||
from services.snippet_dsl_service import SnippetDslService
|
||||
from services.snippet_service import SnippetService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
|
||||
|
||||
|
||||
def _normalize_snippet_list_query_args(query_args: MultiDict[str, str]) -> dict[str, str | list[str]]:
|
||||
normalized: dict[str, str | list[str]] = {}
|
||||
indexed_tag_ids: list[tuple[int, str]] = []
|
||||
|
||||
for key in query_args:
|
||||
match = _TAG_IDS_BRACKET_PATTERN.fullmatch(key)
|
||||
if match:
|
||||
indexed_tag_ids.extend((int(match.group(1)), value) for value in query_args.getlist(key))
|
||||
continue
|
||||
|
||||
value = query_args.get(key)
|
||||
if value is not None:
|
||||
normalized[key] = value
|
||||
|
||||
if indexed_tag_ids:
|
||||
normalized["tag_ids"] = [value for _, value in sorted(indexed_tag_ids)]
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
# Register Pydantic models with Swagger
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
SnippetListQuery,
|
||||
CreateSnippetPayload,
|
||||
UpdateSnippetPayload,
|
||||
SnippetImportPayload,
|
||||
IncludeSecretQuery,
|
||||
)
|
||||
|
||||
# Create namespace models for marshaling
|
||||
snippet_model = console_ns.model("Snippet", snippet_fields)
|
||||
snippet_list_model = console_ns.model("SnippetList", snippet_list_fields)
|
||||
snippet_pagination_model = console_ns.model("SnippetPagination", snippet_pagination_fields)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets")
|
||||
class CustomizedSnippetsApi(Resource):
|
||||
@console_ns.doc("list_customized_snippets")
|
||||
@console_ns.expect(console_ns.models.get(SnippetListQuery.__name__))
|
||||
@console_ns.response(200, "Snippets retrieved successfully", snippet_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""List customized snippets with pagination and search."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
query = SnippetListQuery.model_validate(_normalize_snippet_list_query_args(request.args))
|
||||
|
||||
snippets, total, has_more = SnippetService.get_snippets(
|
||||
tenant_id=current_tenant_id,
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
keyword=query.keyword,
|
||||
is_published=query.is_published,
|
||||
creators=query.creators,
|
||||
tag_ids=query.tag_ids,
|
||||
)
|
||||
|
||||
return {
|
||||
"data": marshal(snippets, snippet_list_fields),
|
||||
"page": query.page,
|
||||
"limit": query.limit,
|
||||
"total": total,
|
||||
"has_more": has_more,
|
||||
}, 200
|
||||
|
||||
@console_ns.doc("create_customized_snippet")
|
||||
@console_ns.expect(console_ns.models.get(CreateSnippetPayload.__name__))
|
||||
@console_ns.response(201, "Snippet created successfully", snippet_model)
|
||||
@console_ns.response(400, "Invalid request")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
"""Create a new customized snippet."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
payload = CreateSnippetPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
snippet_type = SnippetType(payload.type)
|
||||
except ValueError:
|
||||
snippet_type = SnippetType.NODE
|
||||
|
||||
try:
|
||||
if payload.graph is not None:
|
||||
SnippetService.validate_snippet_graph_forbidden_nodes(payload.graph)
|
||||
|
||||
snippet = SnippetService.create_snippet(
|
||||
tenant_id=current_tenant_id,
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
snippet_type=snippet_type,
|
||||
icon_info=payload.icon_info.model_dump() if payload.icon_info else None,
|
||||
input_fields=[f.model_dump() for f in payload.input_fields] if payload.input_fields else None,
|
||||
account=current_user,
|
||||
)
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
return marshal(snippet, snippet_fields), 201
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>")
|
||||
class CustomizedSnippetDetailApi(Resource):
|
||||
@console_ns.doc("get_customized_snippet")
|
||||
@console_ns.response(200, "Snippet retrieved successfully", snippet_model)
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, snippet_id: str):
|
||||
"""Get customized snippet details."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
return marshal(snippet, snippet_fields), 200
|
||||
|
||||
@console_ns.doc("update_customized_snippet")
|
||||
@console_ns.expect(console_ns.models.get(UpdateSnippetPayload.__name__))
|
||||
@console_ns.response(200, "Snippet updated successfully", snippet_model)
|
||||
@console_ns.response(400, "Invalid request")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def patch(self, snippet_id: str):
|
||||
"""Update customized snippet."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
payload = UpdateSnippetPayload.model_validate(console_ns.payload or {})
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
|
||||
if "icon_info" in update_data and update_data["icon_info"] is not None:
|
||||
update_data["icon_info"] = payload.icon_info.model_dump() if payload.icon_info else None
|
||||
|
||||
if not update_data:
|
||||
return {"message": "No valid fields to update"}, 400
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
snippet = session.merge(snippet)
|
||||
snippet = SnippetService.update_snippet(
|
||||
session=session,
|
||||
snippet=snippet,
|
||||
account_id=current_user.id,
|
||||
data=update_data,
|
||||
)
|
||||
session.commit()
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
return marshal(snippet, snippet_fields), 200
|
||||
|
||||
@console_ns.doc("delete_customized_snippet")
|
||||
@console_ns.response(204, "Snippet deleted successfully")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, snippet_id: str):
|
||||
"""Delete customized snippet."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
snippet = session.merge(snippet)
|
||||
SnippetService.delete_snippet(
|
||||
session=session,
|
||||
snippet=snippet,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/export")
|
||||
class CustomizedSnippetExportApi(Resource):
|
||||
@console_ns.doc("export_customized_snippet")
|
||||
@console_ns.doc(description="Export snippet configuration as DSL")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID to export"})
|
||||
@console_ns.response(200, "Snippet exported successfully")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, snippet_id: str):
|
||||
"""Export snippet as DSL."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
# Get include_secret parameter
|
||||
query = IncludeSecretQuery.model_validate(request.args.to_dict())
|
||||
|
||||
with Session(db.engine) as session:
|
||||
export_service = SnippetDslService(session)
|
||||
result = export_service.export_snippet_dsl(snippet=snippet, include_secret=query.include_secret == "true")
|
||||
|
||||
# Set filename with .snippet extension
|
||||
filename = f"{snippet.name}.snippet"
|
||||
encoded_filename = quote(filename)
|
||||
|
||||
response = Response(
|
||||
result,
|
||||
mimetype="application/x-yaml",
|
||||
)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
response.headers["Content-Type"] = "application/x-yaml"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/imports")
|
||||
class CustomizedSnippetImportApi(Resource):
|
||||
@console_ns.doc("import_customized_snippet")
|
||||
@console_ns.doc(description="Import snippet from DSL")
|
||||
@console_ns.expect(console_ns.models.get(SnippetImportPayload.__name__))
|
||||
@console_ns.response(200, "Snippet imported successfully")
|
||||
@console_ns.response(202, "Import pending confirmation")
|
||||
@console_ns.response(400, "Import failed")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
"""Import snippet from DSL."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = SnippetImportPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = SnippetDslService(session)
|
||||
result = import_service.import_snippet(
|
||||
account=current_user,
|
||||
import_mode=payload.mode,
|
||||
yaml_content=payload.yaml_content,
|
||||
yaml_url=payload.yaml_url,
|
||||
snippet_id=payload.snippet_id,
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/imports/<string:import_id>/confirm")
|
||||
class CustomizedSnippetImportConfirmApi(Resource):
|
||||
@console_ns.doc("confirm_snippet_import")
|
||||
@console_ns.doc(description="Confirm a pending snippet import")
|
||||
@console_ns.doc(params={"import_id": "Import ID to confirm"})
|
||||
@console_ns.response(200, "Import confirmed successfully")
|
||||
@console_ns.response(400, "Import failed")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, import_id: str):
|
||||
"""Confirm a pending snippet import."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = SnippetDslService(session)
|
||||
result = import_service.confirm_import(import_id=import_id, account=current_user)
|
||||
session.commit()
|
||||
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/check-dependencies")
|
||||
class CustomizedSnippetCheckDependenciesApi(Resource):
|
||||
@console_ns.doc("check_snippet_dependencies")
|
||||
@console_ns.doc(description="Check dependencies for a snippet")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID"})
|
||||
@console_ns.response(200, "Dependencies checked successfully")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, snippet_id: str):
|
||||
"""Check dependencies for a snippet."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = SnippetDslService(session)
|
||||
result = import_service.check_dependencies(snippet=snippet)
|
||||
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/use-count/increment")
|
||||
class CustomizedSnippetUseCountIncrementApi(Resource):
|
||||
@console_ns.doc("increment_snippet_use_count")
|
||||
@console_ns.doc(description="Increment snippet use count by 1")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID"})
|
||||
@console_ns.response(200, "Use count incremented successfully")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, snippet_id: str):
|
||||
"""Increment snippet use count when it is inserted into a workflow."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet = SnippetService.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
snippet = session.merge(snippet)
|
||||
SnippetService.increment_use_count(session=session, snippet=snippet)
|
||||
session.commit()
|
||||
session.refresh(snippet)
|
||||
|
||||
return {"result": "success", "use_count": snippet.use_count}, 200
|
||||
@ -69,6 +69,7 @@ class BuiltinToolAddPayload(BaseModel):
|
||||
credentials: dict[str, Any]
|
||||
name: str | None = Field(default=None, max_length=30)
|
||||
type: CredentialType
|
||||
visibility: str | None = None
|
||||
|
||||
|
||||
class BuiltinToolUpdatePayload(BaseModel):
|
||||
@ -277,7 +278,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return jsonable_encoder(
|
||||
@ -293,7 +294,7 @@ class ToolBuiltinProviderInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
|
||||
@ -306,7 +307,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
def post(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
payload = BuiltinToolCredentialDeletePayload.model_validate(console_ns.payload or {})
|
||||
@ -324,7 +325,7 @@ class ToolBuiltinProviderAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
def post(self, provider: str):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
@ -338,6 +339,7 @@ class ToolBuiltinProviderAddApi(Resource):
|
||||
credentials=payload.credentials,
|
||||
name=payload.name,
|
||||
api_type=CredentialType.of(payload.type),
|
||||
visibility=payload.visibility,
|
||||
)
|
||||
|
||||
|
||||
@ -348,7 +350,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
def post(self, provider: str):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
user_id = user.id
|
||||
|
||||
@ -370,13 +372,20 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
def get(self, provider: str):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
# Optional list of credential IDs to include even if visibility would hide them
|
||||
# (used when a workflow/agent node still references another member's only_me credential).
|
||||
include_credential_ids = request.args.getlist("include_credential_ids") or [
|
||||
s for s in (request.args.get("include_credential_ids") or "").split(",") if s
|
||||
]
|
||||
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_builtin_tool_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider,
|
||||
user=user,
|
||||
include_credential_ids=include_credential_ids or None,
|
||||
)
|
||||
)
|
||||
|
||||
@ -384,7 +393,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/icon")
|
||||
class ToolBuiltinProviderIconApi(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider)
|
||||
icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
|
||||
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
|
||||
@ -784,7 +793,7 @@ class ToolPluginOAuthApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
tool_provider = ToolProviderID(provider)
|
||||
plugin_id = tool_provider.plugin_id
|
||||
provider_name = tool_provider.provider_name
|
||||
@ -822,7 +831,7 @@ class ToolPluginOAuthApi(Resource):
|
||||
@console_ns.route("/oauth/plugin/<path:provider>/tool/callback")
|
||||
class ToolOAuthCallback(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
context_id = request.cookies.get("context_id")
|
||||
if not context_id:
|
||||
raise Forbidden("context_id not found")
|
||||
@ -859,7 +868,7 @@ class ToolOAuthCallback(Resource):
|
||||
if not credentials:
|
||||
raise Exception("the plugin credentials failed")
|
||||
|
||||
# add credentials to database
|
||||
# add credentials to database — OAuth tokens default to only_me since they're personal
|
||||
BuiltinToolManageService.add_builtin_tool_provider(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
@ -867,6 +876,7 @@ class ToolOAuthCallback(Resource):
|
||||
credentials=dict(credentials),
|
||||
expires_at=expires_at,
|
||||
api_type=CredentialType.OAUTH2,
|
||||
visibility="only_me",
|
||||
)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
@ -878,7 +888,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {})
|
||||
return BuiltinToolManageService.set_default_provider(
|
||||
@ -910,7 +920,7 @@ class ToolOAuthCustomClient(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
|
||||
@ -919,7 +929,7 @@ class ToolOAuthCustomClient(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider):
|
||||
def delete(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
|
||||
@ -931,7 +941,7 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
|
||||
@ -945,13 +955,18 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
def get(self, provider: str):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
include_credential_ids = request.args.getlist("include_credential_ids") or [
|
||||
s for s in (request.args.get("include_credential_ids") or "").split(",") if s
|
||||
]
|
||||
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
user=user,
|
||||
include_credential_ids=include_credential_ids or None,
|
||||
)
|
||||
)
|
||||
|
||||
@ -1151,7 +1166,7 @@ class ToolMCPDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
def get(self, provider_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
@ -1180,7 +1195,7 @@ class ToolMCPUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
def get(self, provider_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
|
||||
@ -77,7 +77,7 @@ class TriggerProviderIconApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
@ -103,7 +103,7 @@ class TriggerProviderInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
"""Get info for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
@ -119,15 +119,18 @@ class TriggerSubscriptionListApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
"""List all trigger subscriptions for the current tenant's provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
TriggerProviderService.list_trigger_provider_subscriptions(
|
||||
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
user=user,
|
||||
)
|
||||
)
|
||||
except ValueError as e:
|
||||
@ -146,7 +149,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
def post(self, provider: str):
|
||||
"""Add a new subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
@ -175,7 +178,7 @@ class TriggerSubscriptionBuilderGetApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, subscription_builder_id):
|
||||
def get(self, provider: str, subscription_builder_id: str):
|
||||
"""Get a subscription instance for a trigger provider"""
|
||||
return jsonable_encoder(
|
||||
TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
|
||||
@ -191,7 +194,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
def post(self, provider: str, subscription_builder_id: str):
|
||||
"""Verify and update a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
@ -223,7 +226,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
def post(self, provider: str, subscription_builder_id: str):
|
||||
"""Update a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
@ -257,7 +260,7 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, subscription_builder_id):
|
||||
def get(self, provider: str, subscription_builder_id: str):
|
||||
"""Get the request logs for a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
@ -280,7 +283,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
def post(self, provider: str, subscription_builder_id: str):
|
||||
"""Build a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
@ -404,7 +407,7 @@ class TriggerOAuthAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
"""Initiate OAuth authorization flow for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
@ -486,7 +489,7 @@ class TriggerOAuthAuthorizeApi(Resource):
|
||||
@console_ns.route("/oauth/plugin/<path:provider>/trigger/callback")
|
||||
class TriggerOAuthCallbackApi(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
"""Handle OAuth callback for trigger provider"""
|
||||
context_id = request.cookies.get("context_id")
|
||||
if not context_id:
|
||||
@ -554,7 +557,7 @@ class TriggerOAuthClientManageApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
"""Get OAuth client configuration for a provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
@ -600,7 +603,7 @@ class TriggerOAuthClientManageApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
def post(self, provider: str):
|
||||
"""Configure custom OAuth client for a provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
@ -626,7 +629,7 @@ class TriggerOAuthClientManageApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider):
|
||||
def delete(self, provider: str):
|
||||
"""Remove custom OAuth client configuration"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
@ -654,7 +657,7 @@ class TriggerSubscriptionVerifyApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_id):
|
||||
def post(self, provider: str, subscription_id: str):
|
||||
"""Verify credentials for an existing subscription (edit mode only)"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
@ -25,13 +25,15 @@ from controllers.console.wraps import (
|
||||
cloud_edition_billing_resource_check,
|
||||
only_edition_enterprise,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import TimestampField, dump_response, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
|
||||
from libs.login import login_required
|
||||
from models.account import Account, Tenant, TenantCustomConfigDict, TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.billing_service import BillingService, SubscriptionPlan
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@ -153,8 +155,9 @@ class TenantListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
tenants = TenantService.get_join_tenants(current_user)
|
||||
tenant_dicts = []
|
||||
is_enterprise_only = dify_config.ENTERPRISE_ENABLED and not dify_config.BILLING_ENABLED
|
||||
@ -228,11 +231,11 @@ class TenantApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[TenantInfoResponse.__name__])
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
if request.path == "/info":
|
||||
logger.warning("Deprecated URL /info was used.")
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
tenant = current_user.current_tenant
|
||||
if not tenant:
|
||||
raise ValueError("No current tenant")
|
||||
@ -256,8 +259,8 @@ class SwitchWorkspaceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = SwitchWorkspacePayload.model_validate(payload)
|
||||
|
||||
@ -281,8 +284,8 @@ class CustomConfigWorkspaceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("workspace_custom")
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = WorkspaceCustomConfigPayload.model_validate(payload)
|
||||
tenant = db.get_or_404(Tenant, current_tenant_id)
|
||||
@ -308,8 +311,8 @@ class WebappLogoWorkspaceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("workspace_custom")
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
@ -349,8 +352,8 @@ class WorkspaceInfoApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
# Change workspace name
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = WorkspaceInfoPayload.model_validate(payload)
|
||||
|
||||
@ -372,13 +375,12 @@ class WorkspacePermissionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_enterprise
|
||||
def get(self):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
"""
|
||||
Get workspace permission settings.
|
||||
Returns permission flags that control workspace features like member invitations and owner transfer.
|
||||
"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ import os
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate
|
||||
from typing import Any, Concatenate, overload
|
||||
|
||||
from flask import abort, request
|
||||
from pydantic import BaseModel, ValidationError
|
||||
@ -37,9 +37,21 @@ ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
|
||||
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
|
||||
|
||||
|
||||
def account_initialization_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@overload
|
||||
def account_initialization_required[T, **P, R](
|
||||
view: Callable[Concatenate[T, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def account_initialization_required[**P, R](view: Callable[P, R]) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
def account_initialization_required[R](view: Callable[..., R]) -> Callable[..., R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def decorated(*args: Any, **kwargs: Any) -> R:
|
||||
# The overloads keep Resource methods method-aware for pyrefly while
|
||||
# preserving support for plain functions used in tests and utilities.
|
||||
# check account initialization
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if current_user.status == AccountStatus.UNINITIALIZED:
|
||||
@ -218,9 +230,21 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
return decorated
|
||||
|
||||
|
||||
def setup_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@overload
|
||||
def setup_required[T, **P, R](
|
||||
view: Callable[Concatenate[T, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def setup_required[**P, R](view: Callable[P, R]) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
def setup_required[R](view: Callable[..., R]) -> Callable[..., R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def decorated(*args: Any, **kwargs: Any) -> R:
|
||||
# The overloads keep Resource methods method-aware for pyrefly while
|
||||
# preserving support for plain functions used in tests and utilities.
|
||||
# check setup
|
||||
if dify_config.EDITION == "SELF_HOSTED" and not db.session.scalar(select(DifySetup).limit(1)):
|
||||
if os.environ.get("INIT_PASSWORD"):
|
||||
@ -552,7 +576,7 @@ def with_current_user_id[T, **P, R](
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
return view(self, str(current_user.id), *args, **kwargs)
|
||||
return view(self, current_user.id, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ from hmac import new as hmac_new
|
||||
from flask import abort, request
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from core.db.session_factory import session_factory
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
@ -44,6 +44,8 @@ def enterprise_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
|
||||
|
||||
def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Inject an EndUser for valid inner API HMAC auth, otherwise pass the request through unchanged."""
|
||||
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not dify_config.INNER_API:
|
||||
@ -72,9 +74,9 @@ def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P,
|
||||
if signature_base64 != token:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
kwargs["user"] = db.session.get(EndUser, user_id)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
with session_factory.create_session() as session:
|
||||
kwargs["user"] = session.get(EndUser, user_id)
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
@ -147,7 +147,7 @@ class AppDescribeApi(AppReadResource):
|
||||
class AppListApi(Resource):
|
||||
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
|
||||
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@auth_router.guard_workspace(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
try:
|
||||
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.openapi.auth.conditions import (
|
||||
EDITION_CE,
|
||||
EDITION_EE,
|
||||
HAS_ALLOWED_ROLES,
|
||||
LOADED_APP_IS_PRIVATE,
|
||||
PATH_HAS_APP_ID,
|
||||
WEBAPP_AUTH_ENABLED,
|
||||
WORKSPACE_MEMBERSHIP_REQUIRED,
|
||||
WORKSPACE_SCOPED,
|
||||
)
|
||||
from controllers.openapi.auth.data import Edition
|
||||
from controllers.openapi.auth.flow import When
|
||||
@ -15,14 +17,18 @@ from controllers.openapi.auth.prepare import (
|
||||
load_app,
|
||||
load_app_access_mode,
|
||||
load_tenant,
|
||||
load_tenant_from_request,
|
||||
load_workspace_role,
|
||||
resolve_external_user,
|
||||
)
|
||||
from controllers.openapi.auth.verify import (
|
||||
check_acl,
|
||||
check_app_access,
|
||||
check_membership,
|
||||
check_app_api_enabled,
|
||||
check_private_app_permission,
|
||||
check_scope,
|
||||
check_workspace_member,
|
||||
check_workspace_mismatch,
|
||||
check_workspace_role,
|
||||
)
|
||||
from libs.oauth_bearer import TokenType
|
||||
|
||||
@ -30,13 +36,17 @@ account_pipeline = AuthPipeline(
|
||||
prepare=[
|
||||
When(PATH_HAS_APP_ID, then=load_app),
|
||||
When(PATH_HAS_APP_ID, then=load_tenant),
|
||||
load_account, # all tokens here are account tokens
|
||||
When(WORKSPACE_MEMBERSHIP_REQUIRED, then=load_tenant_from_request),
|
||||
load_account,
|
||||
When(WORKSPACE_SCOPED, then=load_workspace_role),
|
||||
When(PATH_HAS_APP_ID & EDITION_EE, then=load_app_access_mode),
|
||||
],
|
||||
auth=[
|
||||
When(PATH_HAS_APP_ID, then=check_app_api_enabled),
|
||||
check_scope,
|
||||
When(EDITION_CE & PATH_HAS_APP_ID, then=check_membership),
|
||||
When(EDITION_EE & PATH_HAS_APP_ID & ~WEBAPP_AUTH_ENABLED, then=check_app_access),
|
||||
When(WORKSPACE_SCOPED, then=check_workspace_member),
|
||||
When(PATH_HAS_APP_ID, then=check_workspace_mismatch),
|
||||
When(HAS_ALLOWED_ROLES, then=check_workspace_role),
|
||||
When(PATH_HAS_APP_ID & EDITION_EE & WEBAPP_AUTH_ENABLED, then=check_acl),
|
||||
When(EDITION_EE & LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
|
||||
],
|
||||
@ -50,6 +60,7 @@ external_sso_pipeline = AuthPipeline(
|
||||
When(PATH_HAS_APP_ID, then=load_app_access_mode),
|
||||
],
|
||||
auth=[
|
||||
When(PATH_HAS_APP_ID, then=check_app_api_enabled),
|
||||
check_scope,
|
||||
When(PATH_HAS_APP_ID & WEBAPP_AUTH_ENABLED, then=check_acl),
|
||||
When(LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
|
||||
|
||||
@ -50,4 +50,11 @@ EDITION_SAAS = config_cond(lambda: current_edition() == Edition.SAAS)
|
||||
|
||||
WEBAPP_AUTH_ENABLED = config_cond(lambda: FeatureService.get_system_features().webapp_auth.enabled)
|
||||
|
||||
WORKSPACE_MEMBERSHIP_REQUIRED = request_cond(lambda ctx: ctx.workspace_membership)
|
||||
HAS_ALLOWED_ROLES = request_cond(lambda ctx: ctx.allowed_roles is not None)
|
||||
|
||||
# Caller must belong to the resolved tenant: either an app-scoped path (tenant
|
||||
# from the app) or an explicit workspace-membership path (tenant from request).
|
||||
WORKSPACE_SCOPED = PATH_HAS_APP_ID | WORKSPACE_MEMBERSHIP_REQUIRED
|
||||
|
||||
LOADED_APP_IS_PRIVATE = data_cond(lambda data: data.app_access_mode == WebAppAccessMode.PRIVATE)
|
||||
|
||||
@ -9,7 +9,7 @@ from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from configs import dify_config
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models.account import Account, Tenant
|
||||
from models.account import Account, Tenant, TenantAccountRole
|
||||
from models.model import App, EndUser
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
|
||||
@ -41,6 +41,8 @@ class RequestContext(BaseModel):
|
||||
token_type: TokenType
|
||||
scope: Scope | None = None
|
||||
path_params: dict[str, str]
|
||||
workspace_membership: bool = False
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None
|
||||
|
||||
|
||||
class AuthData(BaseModel):
|
||||
@ -56,10 +58,14 @@ class AuthData(BaseModel):
|
||||
external_identity: ExternalIdentity | None = None
|
||||
path_params: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None
|
||||
|
||||
app: App | None = None
|
||||
tenant: Tenant | None = None
|
||||
app_access_mode: WebAppAccessMode | None = None
|
||||
|
||||
tenant_role: TenantAccountRole | None = None
|
||||
|
||||
caller: Account | EndUser | None = None
|
||||
caller_kind: Literal["account", "end_user"] | None = None
|
||||
|
||||
|
||||
@ -34,6 +34,7 @@ from libs.oauth_bearer import (
|
||||
reset_auth_ctx,
|
||||
set_auth_ctx,
|
||||
)
|
||||
from models.account import TenantAccountRole
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
|
||||
|
||||
@ -56,11 +57,15 @@ class AuthPipeline:
|
||||
view: Callable,
|
||||
*,
|
||||
scope: Scope | None,
|
||||
workspace_membership: bool = False,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
) -> Any:
|
||||
req_ctx = RequestContext(
|
||||
token_type=identity.token_type,
|
||||
scope=scope,
|
||||
path_params=dict(request.view_args or {}),
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
)
|
||||
|
||||
data = AuthData(
|
||||
@ -71,6 +76,7 @@ class AuthPipeline:
|
||||
scopes=frozenset(identity.scopes),
|
||||
tenants=dict(identity.verified_tenants),
|
||||
required_scope=scope,
|
||||
allowed_roles=allowed_roles,
|
||||
path_params=dict(req_ctx.path_params),
|
||||
external_identity=(
|
||||
ExternalIdentity(email=identity.subject_email, issuer=identity.subject_issuer)
|
||||
@ -121,6 +127,41 @@ class PipelineRouter:
|
||||
scope: Scope | None = None,
|
||||
allowed_token_types: frozenset[TokenType] | None = None,
|
||||
edition: frozenset[Edition] | None = None,
|
||||
workspace_membership: bool = False,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
) -> Callable:
|
||||
return self._make_decorator(
|
||||
scope=scope,
|
||||
allowed_token_types=allowed_token_types,
|
||||
edition=edition,
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
)
|
||||
|
||||
def guard_workspace(
|
||||
self,
|
||||
*,
|
||||
scope: Scope | None = None,
|
||||
allowed_token_types: frozenset[TokenType] | None = None,
|
||||
edition: frozenset[Edition] | None = None,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
) -> Callable:
|
||||
return self._make_decorator(
|
||||
scope=scope,
|
||||
allowed_token_types=allowed_token_types,
|
||||
edition=edition,
|
||||
workspace_membership=True,
|
||||
allowed_roles=allowed_roles,
|
||||
)
|
||||
|
||||
def _make_decorator(
|
||||
self,
|
||||
*,
|
||||
scope: Scope | None,
|
||||
allowed_token_types: frozenset[TokenType] | None,
|
||||
edition: frozenset[Edition] | None,
|
||||
workspace_membership: bool,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None,
|
||||
) -> Callable:
|
||||
def decorator(view: Callable) -> Callable:
|
||||
@wraps(view)
|
||||
@ -132,6 +173,8 @@ class PipelineRouter:
|
||||
scope=scope,
|
||||
allowed_token_types=allowed_token_types,
|
||||
edition=edition,
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
)
|
||||
|
||||
return decorated
|
||||
@ -147,6 +190,8 @@ class PipelineRouter:
|
||||
scope: Scope | None,
|
||||
allowed_token_types: frozenset[TokenType] | None,
|
||||
edition: frozenset[Edition] | None,
|
||||
workspace_membership: bool = False,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
) -> Any:
|
||||
# 404 not 403 — this edition doesn't expose the feature at all
|
||||
if edition is not None and current_edition() not in edition:
|
||||
@ -182,7 +227,15 @@ class PipelineRouter:
|
||||
if not license_checked and Edition.EE in route.required_edition:
|
||||
_check_license()
|
||||
|
||||
return route.pipeline._run(identity, args, kwargs, view, scope=scope)
|
||||
return route.pipeline._run(
|
||||
identity,
|
||||
args,
|
||||
kwargs,
|
||||
view,
|
||||
scope=scope,
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
)
|
||||
|
||||
|
||||
def _should_run(step: Any, req_ctx: RequestContext, data: AuthData | None) -> bool:
|
||||
|
||||
@ -1,5 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from flask import request
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
@ -13,16 +16,18 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppAcce
|
||||
|
||||
|
||||
def load_app(data: AuthData) -> None:
|
||||
if data.app is not None:
|
||||
return
|
||||
app_id = data.path_params["app_id"]
|
||||
app = AppService.get_app_by_id(db.session, app_id)
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
if not app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
data.app = app
|
||||
|
||||
|
||||
def load_tenant(data: AuthData) -> None:
|
||||
if data.tenant is not None:
|
||||
return
|
||||
if data.app is None:
|
||||
raise InternalServerError("pipeline_invariant_violated: app not loaded before load_tenant")
|
||||
tenant = TenantService.get_tenant_by_id(db.session, str(data.app.tenant_id))
|
||||
@ -31,7 +36,25 @@ def load_tenant(data: AuthData) -> None:
|
||||
data.tenant = tenant
|
||||
|
||||
|
||||
def load_tenant_from_request(data: AuthData) -> None:
|
||||
if data.tenant is not None:
|
||||
return
|
||||
workspace_id = data.path_params.get("workspace_id") or request.args.get("workspace_id")
|
||||
if not workspace_id:
|
||||
raise NotFound("workspace not found")
|
||||
try:
|
||||
uuid.UUID(workspace_id)
|
||||
except ValueError:
|
||||
raise NotFound("workspace not found")
|
||||
tenant = TenantService.get_tenant_by_id(db.session, workspace_id)
|
||||
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
|
||||
raise NotFound("workspace not found")
|
||||
data.tenant = tenant
|
||||
|
||||
|
||||
def load_account(data: AuthData) -> None:
|
||||
if data.caller is not None:
|
||||
return
|
||||
account = AccountService.get_account_by_id(db.session, str(data.account_id))
|
||||
if account is None:
|
||||
raise Unauthorized("account not found")
|
||||
@ -41,6 +64,19 @@ def load_account(data: AuthData) -> None:
|
||||
data.caller_kind = "account"
|
||||
|
||||
|
||||
def load_workspace_role(data: AuthData) -> None:
|
||||
if data.tenant_role is not None:
|
||||
return
|
||||
if data.tenant is None or data.account_id is None:
|
||||
return
|
||||
if data.caller is not None and getattr(data.caller, "status", None) != "active":
|
||||
return
|
||||
role = TenantService.get_account_role_in_tenant(db.session, str(data.account_id), str(data.tenant.id))
|
||||
if role is None:
|
||||
return
|
||||
data.tenant_role = role
|
||||
|
||||
|
||||
def resolve_external_user(data: AuthData) -> None:
|
||||
if data.tenant is None or data.app is None or data.external_identity is None:
|
||||
raise Unauthorized("missing context for external user resolution")
|
||||
|
||||
@ -1,77 +0,0 @@
|
||||
"""Workspace role gate.
|
||||
|
||||
Layered on top of `validate_bearer` + `accept_subjects(SubjectType.ACCOUNT)`
|
||||
for routes whose access depends on the caller's `TenantAccountJoin.role`
|
||||
in the workspace named by the `workspace_id` path parameter.
|
||||
|
||||
Usage::
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/members")
|
||||
class Members(Resource):
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
@require_workspace_role() # any member
|
||||
def get(self, workspace_id: str): ...
|
||||
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
def post(self, workspace_id: str): ...
|
||||
|
||||
Non-member callers get 404 (matching `GET /openapi/v1/workspaces/<id>`)
|
||||
so workspace IDs do not leak across tenants. A member without one of the
|
||||
allowed roles gets 403.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TypeVar
|
||||
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import try_get_auth_ctx
|
||||
from models.account import TenantAccountRole
|
||||
from services.account_service import TenantService
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., object])
|
||||
|
||||
|
||||
def require_workspace_role(*allowed_roles: TenantAccountRole) -> Callable[[F], F]:
|
||||
"""Gate a route on the caller's role in ``workspace_id``.
|
||||
|
||||
Pass no roles to require only membership. Pass one or more roles to
|
||||
require the caller's role be in that set.
|
||||
"""
|
||||
|
||||
allowed = frozenset(allowed_roles)
|
||||
|
||||
def deco(fn: F) -> F:
|
||||
@wraps(fn)
|
||||
def wrapper(*args: object, **kwargs: object) -> object:
|
||||
ctx = try_get_auth_ctx()
|
||||
if ctx is None or ctx.account_id is None:
|
||||
raise RuntimeError(
|
||||
"require_workspace_role called without account-bearer context; "
|
||||
"stack validate_bearer + accept_subjects(SubjectType.ACCOUNT) above it"
|
||||
)
|
||||
|
||||
workspace_id = kwargs.get("workspace_id")
|
||||
if not workspace_id:
|
||||
raise RuntimeError("require_workspace_role expects a 'workspace_id' route parameter")
|
||||
|
||||
role = TenantService.get_account_role_in_tenant(db.session, str(ctx.account_id), str(workspace_id))
|
||||
|
||||
if role is None:
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
if allowed and role not in allowed:
|
||||
raise Forbidden("insufficient workspace role")
|
||||
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
|
||||
return deco
|
||||
@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from werkzeug.exceptions import Forbidden, Unauthorized
|
||||
from flask import request
|
||||
from werkzeug.exceptions import Forbidden, NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType, check_workspace_membership
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode
|
||||
|
||||
@ -17,17 +18,39 @@ def check_scope(data: AuthData) -> None:
|
||||
raise Forbidden("insufficient_scope")
|
||||
|
||||
|
||||
def check_membership(data: AuthData) -> None:
|
||||
def check_workspace_member(data: AuthData) -> None:
|
||||
"""Assert the caller belongs to the resolved tenant.
|
||||
|
||||
`load_workspace_role` stashes the membership role (None when the caller is
|
||||
not a member or is inactive). A missing membership surfaces as 404, not
|
||||
403, so workspace IDs don't leak across tenants.
|
||||
"""
|
||||
if data.tenant_role is None:
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
|
||||
def check_workspace_mismatch(data: AuthData) -> None:
|
||||
if data.tenant is None:
|
||||
raise Unauthorized("tenant unset")
|
||||
if data.account_id is None:
|
||||
raise Unauthorized("account_id unset")
|
||||
check_workspace_membership(
|
||||
account_id=data.account_id,
|
||||
tenant_id=data.tenant.id,
|
||||
token_hash=data.token_hash,
|
||||
membership_cache=data.tenants,
|
||||
)
|
||||
return
|
||||
request_workspace_id = data.path_params.get("workspace_id") or request.args.get("workspace_id")
|
||||
if request_workspace_id and request_workspace_id != str(data.tenant.id):
|
||||
raise UnprocessableEntity("workspace_id does not match app's workspace")
|
||||
|
||||
|
||||
def check_workspace_role(data: AuthData) -> None:
|
||||
if data.allowed_roles is None:
|
||||
return
|
||||
if data.tenant_role is None:
|
||||
raise NotFound("workspace not found")
|
||||
if data.tenant_role not in data.allowed_roles:
|
||||
raise Forbidden("insufficient workspace role")
|
||||
|
||||
|
||||
def check_app_api_enabled(data: AuthData) -> None:
|
||||
if data.app is None:
|
||||
return
|
||||
if not data.app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
|
||||
|
||||
def check_app_access(data: AuthData) -> None:
|
||||
|
||||
@ -26,7 +26,12 @@ from werkzeug.exceptions import BadRequest
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
AccountPayload,
|
||||
@ -42,7 +47,6 @@ from controllers.openapi._models import (
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType, bearer_feature_required
|
||||
from libs.rate_limit import (
|
||||
LIMIT_APPROVE_CONSOLE,
|
||||
@ -50,6 +54,7 @@ from libs.rate_limit import (
|
||||
LIMIT_LOOKUP_PUBLIC,
|
||||
rate_limit,
|
||||
)
|
||||
from models import Account
|
||||
from services.account_service import TenantService
|
||||
from services.oauth_device_flow import (
|
||||
ACCOUNT_ISSUER_SENTINEL,
|
||||
@ -206,11 +211,12 @@ class DeviceApproveApi(Resource):
|
||||
@account_initialization_required
|
||||
@bearer_feature_required
|
||||
@rate_limit(LIMIT_APPROVE_CONSOLE)
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant: str, account: Account):
|
||||
payload = _validate_json(DeviceMutateRequest)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
account, tenant = current_account_with_tenant()
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
|
||||
found = store.load_by_user_code(user_code)
|
||||
|
||||
@ -5,9 +5,8 @@ endpoints. Account bearers (dfoa_) see every tenant they're a member of.
|
||||
External SSO bearers (dfoe_) have no account_id and so see an empty list —
|
||||
that matches /openapi/v1/account.
|
||||
|
||||
Member-management endpoints are gated by both `accept_subjects` (SSO out)
|
||||
and `require_workspace_role` (membership / role lookup against the path's
|
||||
``workspace_id``).
|
||||
Member-management endpoints use ``guard_workspace`` which enforces
|
||||
workspace membership and optional role requirements via the auth pipeline.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -37,7 +36,6 @@ from controllers.openapi._models import (
|
||||
)
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.auth.role_gate import require_workspace_role
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models import Account, Tenant, TenantAccountJoin
|
||||
@ -152,8 +150,7 @@ class WorkspaceSwitchApi(Resource):
|
||||
"""
|
||||
|
||||
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role()
|
||||
@auth_router.guard_workspace(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def post(self, workspace_id: str, *, auth_data: AuthData):
|
||||
account = _load_account(auth_data.account_id)
|
||||
|
||||
@ -179,8 +176,7 @@ class WorkspaceMembersApi(Resource):
|
||||
|
||||
@openapi_ns.doc(params=query_params_from_model(MemberListQuery))
|
||||
@openapi_ns.response(200, "Member list", openapi_ns.models[MemberListResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role()
|
||||
@auth_router.guard_workspace(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, workspace_id: str, *, auth_data: AuthData):
|
||||
try:
|
||||
query = MemberListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -202,8 +198,11 @@ class WorkspaceMembersApi(Resource):
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[MemberInvitePayload.__name__])
|
||||
@openapi_ns.response(201, "Member invited", openapi_ns.models[MemberInviteResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
@auth_router.guard_workspace(
|
||||
scope=Scope.WORKSPACE_WRITE,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
|
||||
)
|
||||
def post(self, workspace_id: str, *, auth_data: AuthData):
|
||||
payload = _validate_body(MemberInvitePayload)
|
||||
inviter = _load_account(auth_data.account_id)
|
||||
@ -253,8 +252,11 @@ class WorkspaceMemberApi(Resource):
|
||||
"""
|
||||
|
||||
@openapi_ns.response(200, "Member removed", openapi_ns.models[MemberActionResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
@auth_router.guard_workspace(
|
||||
scope=Scope.WORKSPACE_WRITE,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
|
||||
)
|
||||
def delete(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
|
||||
operator = _load_account(auth_data.account_id)
|
||||
tenant = _load_tenant(workspace_id)
|
||||
@ -284,8 +286,11 @@ class WorkspaceMemberRoleApi(Resource):
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[MemberRoleUpdatePayload.__name__])
|
||||
@openapi_ns.response(200, "Role updated", openapi_ns.models[MemberActionResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
@auth_router.guard_workspace(
|
||||
scope=Scope.WORKSPACE_WRITE,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
|
||||
)
|
||||
def put(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
|
||||
payload = _validate_body(MemberRoleUpdatePayload)
|
||||
operator = _load_account(auth_data.account_id)
|
||||
|
||||
@ -28,7 +28,7 @@ from core.errors.error import (
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.helper.trace_id_helper import get_external_trace_id, get_trace_session_id, omit_trace_session_id_from_payload
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
@ -41,12 +41,22 @@ from services.errors.llm import InvokeRateLimitError
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_agent_app_streaming(*, app_mode: AppMode, response_mode: str | None) -> bool:
|
||||
"""Agent App runtime is SSE-only until backend blocking runs are supported."""
|
||||
if app_mode != AppMode.AGENT:
|
||||
return response_mode == "streaming"
|
||||
if response_mode == "blocking":
|
||||
raise BadRequest("Agent App only supports streaming response mode.")
|
||||
return True
|
||||
|
||||
|
||||
class CompletionRequestPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str = Field(default="")
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
retriever_from: str = Field(default="dev")
|
||||
trace_session_id: str | None = Field(default=None, description="Trace session ID for observability grouping")
|
||||
|
||||
|
||||
class ChatRequestPayload(BaseModel):
|
||||
@ -58,6 +68,7 @@ class ChatRequestPayload(BaseModel):
|
||||
retriever_from: str = Field(default="dev")
|
||||
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
|
||||
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
|
||||
trace_session_id: str | None = Field(default=None, description="Trace session ID for observability grouping")
|
||||
|
||||
@field_validator("conversation_id", mode="before")
|
||||
@classmethod
|
||||
@ -103,9 +114,14 @@ class CompletionApi(Resource):
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise AppUnavailableError()
|
||||
|
||||
payload = CompletionRequestPayload.model_validate(service_api_ns.payload or {})
|
||||
payload = CompletionRequestPayload.model_validate(
|
||||
omit_trace_session_id_from_payload(service_api_ns.payload) or {}
|
||||
)
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
trace_session_id = get_trace_session_id(request)
|
||||
if trace_session_id:
|
||||
args["trace_session_id"] = trace_session_id
|
||||
if external_trace_id:
|
||||
args["external_trace_id"] = external_trace_id
|
||||
|
||||
@ -197,17 +213,20 @@ class ChatApi(Resource):
|
||||
Supports conversation management and both blocking and streaming response modes.
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
payload = ChatRequestPayload.model_validate(service_api_ns.payload or {})
|
||||
payload = ChatRequestPayload.model_validate(omit_trace_session_id_from_payload(service_api_ns.payload) or {})
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
trace_session_id = get_trace_session_id(request)
|
||||
if trace_session_id:
|
||||
args["trace_session_id"] = trace_session_id
|
||||
if external_trace_id:
|
||||
args["external_trace_id"] = external_trace_id
|
||||
|
||||
streaming = payload.response_mode == "streaming"
|
||||
streaming = _resolve_agent_app_streaming(app_mode=app_mode, response_mode=payload.response_mode)
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
@ -262,7 +281,7 @@ class ChatStopApi(Resource):
|
||||
def post(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
"""Stop a running chat message generation."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
AppTaskService.stop_task(
|
||||
|
||||
@ -155,7 +155,7 @@ class ConversationApi(Resource):
|
||||
Supports pagination using last_id and limit parameters.
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
query_args = ConversationListQuery.model_validate(request.args.to_dict())
|
||||
@ -199,7 +199,7 @@ class ConversationDetailApi(Resource):
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
"""Delete a specific conversation."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@ -228,7 +228,7 @@ class ConversationRenameApi(Resource):
|
||||
def post(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
"""Rename a conversation or auto-generate a name."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
@ -7,6 +7,7 @@ paused human input forms in workflow/chatflow runs.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource
|
||||
@ -18,6 +19,7 @@ from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
from extensions.ext_database import db
|
||||
from graphon.nodes.human_input.entities import FormInputConfig
|
||||
from libs.helper import to_timestamp
|
||||
from models.model import App, EndUser
|
||||
from services.human_input_service import Form, FormNotFoundError, HumanInputService
|
||||
@ -28,11 +30,11 @@ logger = logging.getLogger(__name__)
|
||||
register_schema_models(service_api_ns, HumanInputFormSubmitPayload)
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form) -> Response:
|
||||
definition_payload = form.get_definition().model_dump()
|
||||
def _jsonify_form_definition(form: Form, *, inputs: Sequence[FormInputConfig] = ()) -> Response:
|
||||
definition_payload = form.get_definition().model_dump(mode="json")
|
||||
payload = {
|
||||
"form_content": definition_payload["rendered_content"],
|
||||
"inputs": definition_payload["inputs"],
|
||||
"inputs": [form_input.model_dump(mode="json") for form_input in inputs],
|
||||
"resolved_default_values": stringify_form_default_values(definition_payload["default_values"]),
|
||||
"user_actions": definition_payload["user_actions"],
|
||||
"expiration_time": to_timestamp(form.expiration_time),
|
||||
@ -75,7 +77,8 @@ class WorkflowHumanInputFormApi(Resource):
|
||||
_ensure_form_belongs_to_app(form, app_model)
|
||||
_ensure_form_is_allowed_for_service_api(form)
|
||||
service.ensure_form_active(form)
|
||||
return _jsonify_form_definition(form)
|
||||
inputs = service.resolve_form_inputs(form)
|
||||
return _jsonify_form_definition(form, inputs=inputs)
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
@service_api_ns.doc("submit_human_input_form")
|
||||
|
||||
@ -56,7 +56,7 @@ class MessageListApi(Resource):
|
||||
Retrieves messages with pagination support using first_id.
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
query_args = MessageListQuery.model_validate(request.args.to_dict())
|
||||
@ -167,7 +167,7 @@ class MessageSuggestedApi(Resource):
|
||||
"""
|
||||
message_id_str = str(message_id)
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
try:
|
||||
|
||||
@ -30,7 +30,7 @@ from core.errors.error import (
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.helper.trace_id_helper import get_external_trace_id, get_trace_session_id, omit_trace_session_id_from_payload
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.base import ResponseModel
|
||||
@ -54,6 +54,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class WorkflowRunPayload(WorkflowRunPayloadBase):
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
trace_session_id: str | None = Field(default=None, description="Trace session ID for observability grouping")
|
||||
|
||||
|
||||
class WorkflowLogQuery(BaseModel):
|
||||
@ -272,8 +273,11 @@ class WorkflowRunApi(Resource):
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {})
|
||||
payload = WorkflowRunPayload.model_validate(omit_trace_session_id_from_payload(service_api_ns.payload) or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
trace_session_id = get_trace_session_id(request)
|
||||
if trace_session_id:
|
||||
args["trace_session_id"] = trace_session_id
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
if external_trace_id:
|
||||
args["external_trace_id"] = external_trace_id
|
||||
@ -328,8 +332,11 @@ class WorkflowRunByIdApi(Resource):
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {})
|
||||
payload = WorkflowRunPayload.model_validate(omit_trace_session_id_from_payload(service_api_ns.payload) or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
trace_session_id = get_trace_session_id(request)
|
||||
if trace_session_id:
|
||||
args["trace_session_id"] = trace_session_id
|
||||
|
||||
# Add workflow_id to args for AppGenerateService
|
||||
args["workflow_id"] = workflow_id
|
||||
|
||||
@ -23,6 +23,7 @@ from . import (
|
||||
feature,
|
||||
files,
|
||||
forgot_password,
|
||||
human_input_file_upload,
|
||||
human_input_form,
|
||||
login,
|
||||
message,
|
||||
@ -46,6 +47,7 @@ __all__ = [
|
||||
"feature",
|
||||
"files",
|
||||
"forgot_password",
|
||||
"human_input_file_upload",
|
||||
"human_input_form",
|
||||
"login",
|
||||
"message",
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import Any, cast
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from werkzeug.exceptions import BadRequest, Unauthorized
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from constants import HEADER_NAME_APP_CODE
|
||||
from controllers.common import fields
|
||||
@ -58,9 +58,6 @@ class AppParameterApi(WebApiResource):
|
||||
)
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
"""Retrieve app parameters."""
|
||||
if not app_model.enable_site:
|
||||
raise BadRequest("Site is disabled.")
|
||||
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
@ -37,6 +37,15 @@ from services.errors.llm import InvokeRateLimitError
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_agent_app_streaming(*, app_mode: AppMode, response_mode: str | None) -> bool:
|
||||
"""Agent App runtime is SSE-only until backend blocking runs are supported."""
|
||||
if app_mode != AppMode.AGENT:
|
||||
return response_mode == "streaming"
|
||||
if response_mode == "blocking":
|
||||
raise BadRequest("Agent App only supports streaming response mode.")
|
||||
return True
|
||||
|
||||
|
||||
class CompletionMessagePayload(BaseModel):
|
||||
inputs: dict[str, Any] = Field(description="Input variables for the completion")
|
||||
query: str = Field(default="", description="Query text for completion")
|
||||
@ -171,13 +180,13 @@ class ChatApi(WebApiResource):
|
||||
)
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
payload = ChatMessagePayload.model_validate(web_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
streaming = payload.response_mode == "streaming"
|
||||
streaming = _resolve_agent_app_streaming(app_mode=app_mode, response_mode=payload.response_mode)
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
@ -228,7 +237,7 @@ class ChatStopApi(WebApiResource):
|
||||
@web_ns.response(200, "Success", web_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
AppTaskService.stop_task(
|
||||
|
||||
@ -83,7 +83,7 @@ class ConversationListApi(WebApiResource):
|
||||
)
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
raw_args = request.args.to_dict()
|
||||
@ -129,7 +129,7 @@ class ConversationApi(WebApiResource):
|
||||
)
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@ -168,7 +168,7 @@ class ConversationRenameApi(WebApiResource):
|
||||
)
|
||||
def post(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@ -206,7 +206,7 @@ class ConversationPinApi(WebApiResource):
|
||||
@web_ns.response(200, "Conversation pinned successfully", web_ns.models[ResultResponse.__name__])
|
||||
def patch(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@ -237,7 +237,7 @@ class ConversationUnPinApi(WebApiResource):
|
||||
@web_ns.response(200, "Conversation unpinned successfully", web_ns.models[ResultResponse.__name__])
|
||||
def patch(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
212
api/controllers/web/human_input_file_upload.py
Normal file
212
api/controllers/web/human_input_file_upload.py
Normal file
@ -0,0 +1,212 @@
|
||||
"""HITL human input form file uploads.
|
||||
|
||||
This controller exposes a single public upload endpoint for both local files and
|
||||
remote URLs. The caller always submits a multipart form: when a non-empty
|
||||
``url`` field is present, the request follows the remote fetch flow; otherwise it
|
||||
falls back to the local file upload flow.
|
||||
"""
|
||||
|
||||
import httpx
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ConfigDict, Field, HttpUrl
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
import services
|
||||
from controllers.common import helpers
|
||||
from controllers.common.errors import (
|
||||
BlockedFileExtensionError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
RemoteFileUploadError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileResponse, FileWithSignedUrl
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.exception import BaseHTTPException
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.file_service import FileService
|
||||
from services.human_input_file_upload_service import (
|
||||
HITL_UPLOAD_TOKEN_PREFIX,
|
||||
HumanInputFileUploadService,
|
||||
InvalidUploadTokenError,
|
||||
)
|
||||
|
||||
|
||||
class InvalidUploadTokenBadRequestError(BaseHTTPException):
|
||||
error_code = "invalid_upload_token"
|
||||
description = "Invalid upload token."
|
||||
code = 400
|
||||
|
||||
|
||||
class InvalidUploadTokenUnauthorizedError(BaseHTTPException):
|
||||
error_code = "invalid_upload_token"
|
||||
description = "Upload token is required."
|
||||
code = 401
|
||||
|
||||
|
||||
class InvalidUploadTokenForbiddenError(BaseHTTPException):
|
||||
error_code = "invalid_upload_token"
|
||||
description = "Upload token is invalid or expired."
|
||||
code = 403
|
||||
|
||||
|
||||
class HumanInputFileUploadFormPayload(BaseModel):
|
||||
"""Parsed multipart form fields for HITL uploads."""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
url: HttpUrl | None = Field(default=None, description="Remote file URL")
|
||||
|
||||
|
||||
register_schema_models(web_ns, HumanInputFileUploadFormPayload, FileResponse, FileWithSignedUrl)
|
||||
|
||||
|
||||
def _create_upload_service() -> HumanInputFileUploadService:
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
workflow_run_repository = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_factory)
|
||||
return HumanInputFileUploadService(
|
||||
session_factory=session_factory,
|
||||
workflow_run_repository=workflow_run_repository,
|
||||
)
|
||||
|
||||
|
||||
def _extract_hitl_upload_token() -> str:
|
||||
"""Read HITL upload token from Authorization without invoking other bearer auth chains."""
|
||||
|
||||
authorization = request.headers.get("Authorization")
|
||||
if authorization is None:
|
||||
raise InvalidUploadTokenUnauthorizedError()
|
||||
|
||||
parts = authorization.split()
|
||||
if len(parts) != 2:
|
||||
raise InvalidUploadTokenUnauthorizedError()
|
||||
|
||||
scheme, token = parts
|
||||
if scheme.lower() != "bearer":
|
||||
raise InvalidUploadTokenBadRequestError()
|
||||
if not token:
|
||||
raise InvalidUploadTokenUnauthorizedError()
|
||||
if not token.startswith(HITL_UPLOAD_TOKEN_PREFIX):
|
||||
raise InvalidUploadTokenBadRequestError()
|
||||
return token
|
||||
|
||||
|
||||
def _validate_context(service: HumanInputFileUploadService, token: str):
|
||||
try:
|
||||
return service.validate_upload_token(token)
|
||||
except InvalidUploadTokenError as exc:
|
||||
raise InvalidUploadTokenForbiddenError() from exc
|
||||
|
||||
|
||||
def _parse_local_upload_file():
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
file = request.files["file"]
|
||||
if not file.filename:
|
||||
from controllers.common.errors import FilenameNotExistsError
|
||||
|
||||
raise FilenameNotExistsError()
|
||||
|
||||
return file
|
||||
|
||||
|
||||
def _parse_upload_form() -> HumanInputFileUploadFormPayload:
|
||||
return HumanInputFileUploadFormPayload.model_validate(request.form.to_dict(flat=True))
|
||||
|
||||
|
||||
def _upload_local_file(context):
|
||||
file = _parse_local_upload_file()
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename or "",
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=context.owner,
|
||||
source=None,
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
except services.errors.file.BlockedFileExtensionError as exc:
|
||||
raise BlockedFileExtensionError() from exc
|
||||
|
||||
response = FileResponse.model_validate(upload_file, from_attributes=True)
|
||||
return upload_file.id, response
|
||||
|
||||
|
||||
def _upload_remote_file(context, url: str):
|
||||
try:
|
||||
resp = ssrf_proxy.head(url=url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
|
||||
except httpx.RequestError as exc:
|
||||
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(exc)}")
|
||||
|
||||
file_info = helpers.guess_file_info_from_response(resp)
|
||||
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
|
||||
raise FileTooLargeError()
|
||||
|
||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
mimetype=file_info.mimetype,
|
||||
user=context.owner,
|
||||
source_url=url,
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
except services.errors.file.BlockedFileExtensionError as exc:
|
||||
raise BlockedFileExtensionError() from exc
|
||||
|
||||
response = FileWithSignedUrl(
|
||||
id=upload_file.id,
|
||||
name=upload_file.name,
|
||||
size=upload_file.size,
|
||||
extension=upload_file.extension,
|
||||
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
|
||||
mime_type=upload_file.mime_type,
|
||||
created_by=upload_file.created_by,
|
||||
created_at=int(upload_file.created_at.timestamp()),
|
||||
)
|
||||
return upload_file.id, response
|
||||
|
||||
|
||||
@web_ns.route("/human-input-forms/files")
|
||||
@web_ns.response(201, "File uploaded successfully", web_ns.models[FileResponse.__name__])
|
||||
class HumanInputFileUploadApi(Resource):
|
||||
def post(self):
|
||||
"""Upload one local file or remote URL file for a HITL human input form."""
|
||||
|
||||
token = _extract_hitl_upload_token()
|
||||
upload_service = _create_upload_service()
|
||||
context = _validate_context(upload_service, token)
|
||||
form = _parse_upload_form()
|
||||
|
||||
# The browser always submits multipart/form-data. A non-empty `url`
|
||||
# switches the endpoint into the remote-fetch flow; otherwise the
|
||||
# request must carry a local `file`.
|
||||
if form.url is not None:
|
||||
file_id, response = _upload_remote_file(context=context, url=str(form.url))
|
||||
else:
|
||||
file_id, response = _upload_local_file(context=context)
|
||||
|
||||
upload_service.record_upload_file(context=context, file_id=file_id)
|
||||
return response.model_dump(mode="json"), 201
|
||||
@ -4,27 +4,42 @@ Web App Human Input Form APIs.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
|
||||
from controllers.web.site import serialize_app_site_payload
|
||||
from extensions.ext_database import db
|
||||
from graphon.nodes.human_input.entities import FormInputConfig
|
||||
from libs.helper import RateLimiter, extract_remote_ip, to_timestamp
|
||||
from models.account import TenantStatus
|
||||
from models.model import App, Site
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.human_input_file_upload_service import HumanInputFileUploadService
|
||||
from services.human_input_service import Form, FormNotFoundError, HumanInputService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputUploadTokenResponse(BaseModel):
|
||||
upload_token: str
|
||||
expires_at: int
|
||||
|
||||
|
||||
register_schema_models(web_ns, HumanInputUploadTokenResponse)
|
||||
|
||||
|
||||
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
|
||||
prefix="web_form_submit_rate_limit",
|
||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
||||
@ -35,6 +50,20 @@ _FORM_ACCESS_RATE_LIMITER = RateLimiter(
|
||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
||||
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
|
||||
)
|
||||
_FORM_UPLOAD_TOKEN_RATE_LIMITER = RateLimiter(
|
||||
prefix="web_form_upload_token_rate_limit",
|
||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
||||
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
def _create_upload_service() -> HumanInputFileUploadService:
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
workflow_run_repository = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_factory)
|
||||
return HumanInputFileUploadService(
|
||||
session_factory=session_factory,
|
||||
workflow_run_repository=workflow_run_repository,
|
||||
)
|
||||
|
||||
|
||||
class FormDefinitionPayload(TypedDict):
|
||||
@ -46,12 +75,17 @@ class FormDefinitionPayload(TypedDict):
|
||||
site: NotRequired[dict]
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
|
||||
def _jsonify_form_definition(
|
||||
form: Form,
|
||||
*,
|
||||
inputs: Sequence[FormInputConfig] = (),
|
||||
site_payload: dict | None = None,
|
||||
) -> Response:
|
||||
"""Return the form payload (optionally with site) as a JSON response."""
|
||||
definition_payload = form.get_definition().model_dump()
|
||||
definition_payload = form.get_definition().model_dump(mode="json")
|
||||
payload: FormDefinitionPayload = {
|
||||
"form_content": definition_payload["rendered_content"],
|
||||
"inputs": definition_payload["inputs"],
|
||||
"inputs": [i.model_dump(mode="json") for i in inputs],
|
||||
"resolved_default_values": stringify_form_default_values(definition_payload["default_values"]),
|
||||
"user_actions": definition_payload["user_actions"],
|
||||
"expiration_time": to_timestamp(form.expiration_time),
|
||||
@ -61,6 +95,33 @@ def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Re
|
||||
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
|
||||
|
||||
|
||||
@web_ns.route("/form/human_input/<string:form_token>/upload-token")
|
||||
class HumanInputFormUploadTokenApi(Resource):
|
||||
"""API for issuing HITL upload tokens for active human input forms."""
|
||||
|
||||
def post(self, form_token: str):
|
||||
"""
|
||||
Issue an upload token for a human input form.
|
||||
|
||||
POST /api/form/human_input/<form_token>/upload-token
|
||||
"""
|
||||
ip_address = extract_remote_ip(request)
|
||||
if _FORM_UPLOAD_TOKEN_RATE_LIMITER.is_rate_limited(ip_address):
|
||||
raise WebFormRateLimitExceededError()
|
||||
_FORM_UPLOAD_TOKEN_RATE_LIMITER.increment_rate_limit(ip_address)
|
||||
|
||||
try:
|
||||
token = _create_upload_service().issue_upload_token(form_token)
|
||||
except FormNotFoundError:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
response = HumanInputUploadTokenResponse(
|
||||
upload_token=token.upload_token,
|
||||
expires_at=to_timestamp(token.expires_at),
|
||||
)
|
||||
return response.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@web_ns.route("/form/human_input/<string:form_token>")
|
||||
class HumanInputFormApi(Resource):
|
||||
"""API for getting and submitting human input forms via the web app."""
|
||||
@ -89,8 +150,13 @@ class HumanInputFormApi(Resource):
|
||||
|
||||
service.ensure_form_active(form)
|
||||
app_model, site = _get_app_site_from_form(form)
|
||||
inputs = service.resolve_form_inputs(form)
|
||||
|
||||
return _jsonify_form_definition(form, site_payload=serialize_app_site_payload(app_model, site, None))
|
||||
return _jsonify_form_definition(
|
||||
form,
|
||||
inputs=inputs,
|
||||
site_payload=serialize_app_site_payload(app_model, site, None),
|
||||
)
|
||||
|
||||
# def post(self, _app_model: App, _end_user: EndUser, form_token: str):
|
||||
def post(self, form_token: str):
|
||||
|
||||
@ -83,7 +83,7 @@ class MessageListApi(WebApiResource):
|
||||
)
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
raw_args = request.args.to_dict()
|
||||
@ -225,7 +225,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
||||
)
|
||||
def get(self, app_model: App, end_user: EndUser, message_id: UUID):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
message_id_str = str(message_id)
|
||||
|
||||
@ -9,7 +9,7 @@ from controllers.common.errors import (
|
||||
RemoteFileUploadError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file import remote_fetcher
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
|
||||
from graphon.file import helpers as file_helpers
|
||||
@ -60,10 +60,10 @@ class RemoteFileInfoApi(WebApiResource):
|
||||
HTTPException: If the remote file cannot be accessed
|
||||
"""
|
||||
decoded_url = helpers.decode_remote_url(url, request.query_string)
|
||||
resp = ssrf_proxy.head(decoded_url)
|
||||
resp = remote_fetcher.make_request("HEAD", decoded_url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
# failed back to get method
|
||||
resp = ssrf_proxy.get(decoded_url, timeout=3)
|
||||
resp = remote_fetcher.make_request("GET", decoded_url, timeout=3)
|
||||
resp.raise_for_status()
|
||||
info = RemoteFileInfo(
|
||||
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
|
||||
@ -112,9 +112,9 @@ class RemoteFileUploadApi(WebApiResource):
|
||||
url = str(payload.url)
|
||||
|
||||
try:
|
||||
resp = ssrf_proxy.head(url=url)
|
||||
resp = remote_fetcher.make_request("HEAD", url=url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
|
||||
resp = remote_fetcher.make_request("GET", url=url, timeout=3, follow_redirects=True)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
|
||||
except httpx.RequestError as e:
|
||||
@ -125,7 +125,7 @@ class RemoteFileUploadApi(WebApiResource):
|
||||
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
|
||||
raise FileTooLargeError
|
||||
|
||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
content = resp.content if resp.request.method == "GET" else remote_fetcher.make_request("GET", url).content
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
|
||||
@ -237,7 +237,9 @@ class EasyUIBasedAppConfig(AppConfig):
|
||||
"""
|
||||
|
||||
app_model_config_from: EasyUIBasedAppModelConfigFrom
|
||||
app_model_config_id: str
|
||||
# Optional: an Agent App has no legacy app_model_config row, so the id may be
|
||||
# absent (persistence then stores NULL for the conversation's id).
|
||||
app_model_config_id: str | None = None
|
||||
app_model_config_dict: dict[str, Any]
|
||||
model: ModelConfigEntity
|
||||
prompt_template: PromptTemplateEntity
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import RagPipelineVariableEntity
|
||||
from graphon.variables.input_entities import VariableEntity
|
||||
@ -20,10 +22,32 @@ class WorkflowVariablesConfigManager:
|
||||
|
||||
# variables
|
||||
for variable in user_input_form:
|
||||
cls._normalize_json_schema(variable)
|
||||
variables.append(VariableEntity.model_validate(variable))
|
||||
|
||||
return variables
|
||||
|
||||
@staticmethod
|
||||
def _normalize_json_schema(variable: dict[str, Any]) -> None:
|
||||
"""
|
||||
Normalize ``json_schema`` from a JSON string to a dict.
|
||||
|
||||
The workflow graph is stored as JSON in the database. When a JSON
|
||||
object variable carries a ``json_schema`` field, nested dicts are
|
||||
preserved correctly, but older data or certain serialization paths
|
||||
may store it as a JSON *string* instead of a native dict.
|
||||
|
||||
``VariableEntity.json_schema`` expects ``dict | None``, so we
|
||||
deserialize the string here before handing it to Pydantic.
|
||||
"""
|
||||
json_schema = variable.get("json_schema")
|
||||
if isinstance(json_schema, str):
|
||||
try:
|
||||
variable["json_schema"] = json.loads(json_schema)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# Leave as-is; Pydantic validation will surface the error.
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def convert_rag_pipeline_variable(cls, workflow: Workflow, start_node_id: str) -> list[RagPipelineVariableEntity]:
|
||||
"""
|
||||
|
||||
@ -40,7 +40,7 @@ from core.app.entities.task_entities import (
|
||||
ChatbotAppStreamResponse,
|
||||
)
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args, extract_trace_session_id_from_args
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
@ -64,6 +64,12 @@ from services.workflow_draft_variable_service import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _extract_trace_session_id_from_debug_args(args: Mapping[str, Any] | Any) -> dict[str, str]:
|
||||
if isinstance(args, Mapping):
|
||||
return extract_trace_session_id_from_args(args)
|
||||
return extract_trace_session_id_from_args({"trace_session_id": getattr(args, "trace_session_id", None)})
|
||||
|
||||
|
||||
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
_dialogue_count: int
|
||||
|
||||
@ -140,6 +146,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
extras = {
|
||||
"auto_generate_conversation_name": args.get("auto_generate_name", False),
|
||||
**extract_external_trace_id_from_args(args),
|
||||
**extract_trace_session_id_from_args(args),
|
||||
}
|
||||
|
||||
# get conversation
|
||||
@ -331,7 +338,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
extras={
|
||||
"auto_generate_conversation_name": False,
|
||||
**_extract_trace_session_id_from_debug_args(args),
|
||||
},
|
||||
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
|
||||
node_id=node_id, inputs=args["inputs"]
|
||||
),
|
||||
@ -417,7 +427,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
extras={
|
||||
"auto_generate_conversation_name": False,
|
||||
**_extract_trace_session_id_from_debug_args(args),
|
||||
},
|
||||
single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args.inputs),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
|
||||
@ -131,6 +131,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
invoke_from=invoke_from,
|
||||
user_from=user_from,
|
||||
trace_session_id=self.application_generate_entity.extras.get("trace_session_id"),
|
||||
)
|
||||
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
# Handle single iteration or single loop run
|
||||
@ -139,6 +140,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
single_loop_run=self.application_generate_entity.single_loop_run,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
trace_session_id=self.application_generate_entity.extras.get("trace_session_id"),
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
@ -199,6 +201,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
root_node_id=root_node_id,
|
||||
trace_session_id=self.application_generate_entity.extras.get("trace_session_id"),
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
0
api/core/app/apps/agent_app/__init__.py
Normal file
0
api/core/app/apps/agent_app/__init__.py
Normal file
106
api/core/app/apps/agent_app/app_config_manager.py
Normal file
106
api/core/app/apps/agent_app/app_config_manager.py
Normal file
@ -0,0 +1,106 @@
|
||||
"""Build the EasyUI-style app config for an Agent App from its Agent Soul.
|
||||
|
||||
An Agent App has no legacy ``app_model_config``: its model / prompt live in the
|
||||
bound Agent Soul snapshot. To ride the existing chat message + SSE pipeline we
|
||||
synthesize an ``app_model_config``-shaped dict from the Soul (model + system
|
||||
prompt) plus any app-level feature flags (opening statement, follow-up, …)
|
||||
stored on ``app_model_config`` when present, then reuse the same sub-managers
|
||||
the chat app type uses.
|
||||
"""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
|
||||
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
|
||||
from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager
|
||||
from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager
|
||||
from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager
|
||||
from core.app.app_config.entities import (
|
||||
EasyUIBasedAppConfig,
|
||||
EasyUIBasedAppModelConfigFrom,
|
||||
PromptTemplateEntity,
|
||||
)
|
||||
from models.agent_config_entities import AgentSoulConfig
|
||||
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation
|
||||
|
||||
|
||||
class AgentAppConfig(EasyUIBasedAppConfig):
|
||||
"""Agent App config entity (EasyUI-shaped so it rides the chat pipeline).
|
||||
|
||||
``app_model_config_id`` is inherited as ``str | None``: an Agent App may have
|
||||
no legacy ``app_model_config`` row, in which case persistence stores ``NULL``
|
||||
for the conversation's ``app_model_config_id``.
|
||||
"""
|
||||
|
||||
|
||||
class AgentAppConfigManager(BaseAppConfigManager):
|
||||
@classmethod
|
||||
def get_app_config(
|
||||
cls,
|
||||
*,
|
||||
app_model: App,
|
||||
agent_soul: AgentSoulConfig,
|
||||
app_model_config: AppModelConfig | None = None,
|
||||
conversation: Conversation | None = None,
|
||||
) -> AgentAppConfig:
|
||||
"""Build the Agent App config from the Agent Soul (+ optional feature flags)."""
|
||||
config_dict = cls._synthesize_config_dict(agent_soul, app_model_config)
|
||||
# The synthesized dict is shaped like an app_model_config; the EasyUI
|
||||
# sub-managers type their param as AppModelConfigDict (a TypedDict).
|
||||
typed_config = cast(AppModelConfigDict, config_dict)
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
||||
app_config = AgentAppConfig(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
app_mode=app_mode,
|
||||
# The config is derived from the Agent Soul snapshot, not a legacy
|
||||
# app_model_config row; the id is informational only.
|
||||
app_model_config_from=EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG,
|
||||
app_model_config_id=app_model_config.id if app_model_config else None,
|
||||
app_model_config_dict=config_dict,
|
||||
model=ModelConfigManager.convert(config=typed_config),
|
||||
prompt_template=PromptTemplateConfigManager.convert(config=typed_config),
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
|
||||
dataset=DatasetConfigManager.convert(config=typed_config),
|
||||
additional_features=cls.convert_features(config_dict, app_mode),
|
||||
)
|
||||
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
|
||||
config=typed_config
|
||||
)
|
||||
return app_config
|
||||
|
||||
@staticmethod
|
||||
def _synthesize_config_dict(
|
||||
agent_soul: AgentSoulConfig,
|
||||
app_model_config: AppModelConfig | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Shape a Soul + feature flags into an ``app_model_config``-style dict.
|
||||
|
||||
Feature flags (opening statement / follow-up / tts / stt / citations /
|
||||
moderation / annotation) come from ``app_model_config`` when present
|
||||
(Q3: stored there), otherwise defaults; model + prompt always come from
|
||||
the Agent Soul (the single source of truth for those).
|
||||
"""
|
||||
base: dict[str, Any] = dict(app_model_config.to_dict()) if app_model_config else {}
|
||||
|
||||
model = agent_soul.model
|
||||
if model is not None:
|
||||
base["model"] = {
|
||||
"provider": model.model_provider,
|
||||
"name": model.model,
|
||||
"mode": "chat",
|
||||
"completion_params": model.model_settings.model_dump(mode="json", exclude_none=True),
|
||||
}
|
||||
# The Agent Soul system prompt rides the EasyUI "simple" prompt slot; the
|
||||
# agent backend is the real prompt authority, this only feeds the chat
|
||||
# pipeline's bookkeeping (token counting, persistence).
|
||||
base["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
|
||||
base["pre_prompt"] = agent_soul.prompt.system_prompt or ""
|
||||
# Agent App takes the user message directly; no completion-style inputs form.
|
||||
base.setdefault("user_input_form", [])
|
||||
return base
|
||||
|
||||
|
||||
__all__ = ["AgentAppConfig", "AgentAppConfigManager"]
|
||||
329
api/core/app/apps/agent_app/app_generator.py
Normal file
329
api/core/app/apps/agent_app/app_generator.py
Normal file
@ -0,0 +1,329 @@
|
||||
"""Agent App generator: orchestrate one conversation turn for an Agent App.
|
||||
|
||||
Mirrors the agent_chat generator (conversation + message + queue + streamed
|
||||
response over the EasyUI chat pipeline), but the backing config comes from the
|
||||
bound Agent Soul and the answer is produced by ``AgentAppRunner`` calling the
|
||||
dify-agent backend rather than an in-process LLM/ReAct loop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import select
|
||||
|
||||
from clients.agent_backend import AgentBackendRunEventAdapter
|
||||
from clients.agent_backend.factory import create_agent_backend_run_client
|
||||
from configs import dify_config
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.apps.agent_app.app_config_manager import AgentAppConfigManager
|
||||
from core.app.apps.agent_app.app_runner import AgentAppRunner
|
||||
from core.app.apps.agent_app.generate_response_converter import AgentAppGenerateResponseConverter
|
||||
from core.app.apps.agent_app.runtime_request_builder import AgentAppRuntimeRequestBuilder
|
||||
from core.app.apps.agent_app.session_store import AgentAppRuntimeSessionStore
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AgentAppGenerateEntity,
|
||||
DifyRunContext,
|
||||
InvokeFrom,
|
||||
UserFrom,
|
||||
)
|
||||
from core.app.llm.model_access import build_dify_model_access
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, EndUser, Message
|
||||
from models.agent import Agent, AgentConfigSnapshot, AgentScope, AgentSource, AgentStatus
|
||||
from models.agent_config_entities import AgentSoulConfig
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentAppGeneratorError(ValueError):
|
||||
"""Raised when an Agent App turn cannot be set up."""
|
||||
|
||||
|
||||
class AgentAppGenerator(MessageBasedAppGenerator):
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[Mapping | str, None, None]:
|
||||
if not streaming:
|
||||
raise AgentAppGeneratorError("Agent App only supports streaming mode")
|
||||
|
||||
query = args.get("query")
|
||||
if not isinstance(query, str) or not query.strip():
|
||||
raise AgentAppGeneratorError("query is required")
|
||||
query = query.replace("\x00", "")
|
||||
inputs = args["inputs"]
|
||||
|
||||
# Resolve the bound roster Agent + its published Agent Soul snapshot.
|
||||
agent, snapshot, agent_soul = self._resolve_agent(app_model)
|
||||
|
||||
conversation = None
|
||||
conversation_id = args.get("conversation_id")
|
||||
if conversation_id:
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model, conversation_id=conversation_id, user=user
|
||||
)
|
||||
|
||||
# Build the EasyUI-shaped config from the Agent Soul so the chat pipeline
|
||||
# can persist usage; the answer itself comes from the agent backend.
|
||||
app_model_config = app_model.app_model_config
|
||||
app_config = AgentAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
agent_soul=agent_soul,
|
||||
app_model_config=app_model_config,
|
||||
conversation=conversation,
|
||||
)
|
||||
model_conf = ModelConfigConverter.convert(app_config)
|
||||
|
||||
trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id)
|
||||
|
||||
application_generate_entity = AgentAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
model_conf=model_conf,
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=self._prepare_user_inputs(
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
files=[],
|
||||
parent_message_id=(
|
||||
args.get("parent_message_id")
|
||||
if invoke_from not in {InvokeFrom.SERVICE_API, InvokeFrom.OPENAPI}
|
||||
else UUID_NIL
|
||||
),
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
extras={
|
||||
"auto_generate_conversation_name": args.get("auto_generate_name", True),
|
||||
},
|
||||
call_depth=0,
|
||||
trace_manager=trace_manager,
|
||||
agent_id=agent.id,
|
||||
agent_config_snapshot_id=snapshot.id,
|
||||
)
|
||||
|
||||
conversation, message = self._init_generate_records(application_generate_entity, conversation)
|
||||
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
context = contextvars.copy_context()
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"context": context,
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"conversation_id": conversation.id,
|
||||
"message_id": message.id,
|
||||
"user_from": UserFrom.ACCOUNT if isinstance(user, Account) else UserFrom.END_USER,
|
||||
},
|
||||
)
|
||||
worker_thread.start()
|
||||
|
||||
response = self._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=streaming,
|
||||
)
|
||||
return AgentAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
||||
def _generate_worker(
|
||||
self,
|
||||
*,
|
||||
flask_app: Flask,
|
||||
context: contextvars.Context,
|
||||
application_generate_entity: AgentAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
user_from: UserFrom,
|
||||
) -> None:
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||
try:
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
message = self._get_message(message_id)
|
||||
app_config = application_generate_entity.app_config
|
||||
|
||||
# Apply app-level input guards (content moderation + annotation
|
||||
# reply) before reaching the Agent backend, mirroring the EasyUI
|
||||
# chat / agent-chat runners. These can short-circuit the turn.
|
||||
app_model = db.session.get(App, app_config.app_id)
|
||||
if app_model is None:
|
||||
raise AgentAppGeneratorError("App not found")
|
||||
handled, query = self._run_input_guards(
|
||||
application_generate_entity=application_generate_entity,
|
||||
app_model=app_model,
|
||||
message=message,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
if handled:
|
||||
return
|
||||
|
||||
dify_context = DifyRunContext(
|
||||
tenant_id=app_config.tenant_id,
|
||||
app_id=app_config.app_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
)
|
||||
credentials_provider, _ = build_dify_model_access(dify_context)
|
||||
_, _, agent_soul = self._resolve_agent_by_id(
|
||||
tenant_id=app_config.tenant_id,
|
||||
agent_id=application_generate_entity.agent_id,
|
||||
snapshot_id=application_generate_entity.agent_config_snapshot_id,
|
||||
)
|
||||
|
||||
runner = AgentAppRunner(
|
||||
request_builder=AgentAppRuntimeRequestBuilder(credentials_provider=credentials_provider),
|
||||
agent_backend_client=create_agent_backend_run_client(
|
||||
base_url=dify_config.AGENT_BACKEND_BASE_URL,
|
||||
use_fake=dify_config.AGENT_BACKEND_USE_FAKE,
|
||||
fake_scenario=dify_config.AGENT_BACKEND_FAKE_SCENARIO,
|
||||
),
|
||||
event_adapter=AgentBackendRunEventAdapter(),
|
||||
session_store=AgentAppRuntimeSessionStore(),
|
||||
)
|
||||
runner.run(
|
||||
dify_context=dify_context,
|
||||
agent_id=application_generate_entity.agent_id,
|
||||
agent_config_snapshot_id=application_generate_entity.agent_config_snapshot_id,
|
||||
agent_soul=agent_soul,
|
||||
conversation_id=conversation.id,
|
||||
query=query,
|
||||
message_id=message.id,
|
||||
model_name=application_generate_entity.model_conf.model,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error in Agent App generate worker")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _run_input_guards(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: AgentAppGenerateEntity,
|
||||
app_model: App,
|
||||
message: Message,
|
||||
queue_manager: AppQueueManager,
|
||||
) -> tuple[bool, str]:
|
||||
"""Apply input moderation + annotation reply before the backend call.
|
||||
|
||||
Returns ``(handled, query)``: when ``handled`` is True a direct answer
|
||||
has already been published (a blocked/preset moderation response or a
|
||||
matched annotation) and the backend turn must be skipped. Otherwise
|
||||
``query`` is the possibly moderation-overridden query to send onward.
|
||||
"""
|
||||
from core.app.apps.agent_app.app_runner import publish_text_answer
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
|
||||
app_config = application_generate_entity.app_config
|
||||
model_name = application_generate_entity.model_conf.model
|
||||
query = application_generate_entity.query
|
||||
|
||||
# content moderation (sensitive_word_avoidance); a blocked input yields a
|
||||
# preset answer, an "overridden" action returns a sanitized query.
|
||||
try:
|
||||
_, _, query = InputModeration().check(
|
||||
app_id=app_config.app_id,
|
||||
tenant_id=app_config.tenant_id,
|
||||
app_config=app_config,
|
||||
inputs=dict(application_generate_entity.inputs),
|
||||
query=query or "",
|
||||
message_id=message.id,
|
||||
trace_manager=application_generate_entity.trace_manager,
|
||||
)
|
||||
except ModerationError as e:
|
||||
publish_text_answer(queue_manager=queue_manager, model_name=model_name, answer=str(e))
|
||||
return True, query
|
||||
|
||||
# annotation reply: a matching annotation answers the turn deterministically.
|
||||
if query:
|
||||
annotation_reply = AnnotationReplyFeature().query(
|
||||
app_record=app_model,
|
||||
message=message,
|
||||
query=query,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
)
|
||||
if annotation_reply:
|
||||
queue_manager.publish(
|
||||
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
publish_text_answer(queue_manager=queue_manager, model_name=model_name, answer=annotation_reply.content)
|
||||
return True, query
|
||||
|
||||
return False, query
|
||||
|
||||
def _resolve_agent(self, app_model: App) -> tuple[Agent, AgentConfigSnapshot, AgentSoulConfig]:
|
||||
agent = db.session.scalar(
|
||||
select(Agent).where(
|
||||
Agent.app_id == app_model.id,
|
||||
Agent.scope == AgentScope.ROSTER,
|
||||
Agent.source == AgentSource.AGENT_APP,
|
||||
Agent.status == AgentStatus.ACTIVE,
|
||||
)
|
||||
)
|
||||
if agent is None:
|
||||
raise AgentAppGeneratorError("Agent App has no bound Agent")
|
||||
return self._resolve_agent_by_id(
|
||||
tenant_id=app_model.tenant_id, agent_id=agent.id, snapshot_id=agent.active_config_snapshot_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_agent_by_id(
|
||||
*, tenant_id: str, agent_id: str, snapshot_id: str | None
|
||||
) -> tuple[Agent, AgentConfigSnapshot, AgentSoulConfig]:
|
||||
agent = db.session.scalar(select(Agent).where(Agent.id == agent_id, Agent.tenant_id == tenant_id))
|
||||
if agent is None:
|
||||
raise AgentAppGeneratorError("Agent not found")
|
||||
if not snapshot_id:
|
||||
raise AgentAppGeneratorError("Agent has no published version")
|
||||
snapshot = db.session.scalar(select(AgentConfigSnapshot).where(AgentConfigSnapshot.id == snapshot_id))
|
||||
if snapshot is None:
|
||||
raise AgentAppGeneratorError("Agent published version not found")
|
||||
agent_soul = AgentSoulConfig.model_validate(snapshot.config_snapshot_dict)
|
||||
return agent, snapshot, agent_soul
|
||||
|
||||
|
||||
__all__ = ["AgentAppGenerator", "AgentAppGeneratorError"]
|
||||
200
api/core/app/apps/agent_app/app_runner.py
Normal file
200
api/core/app/apps/agent_app/app_runner.py
Normal file
@ -0,0 +1,200 @@
|
||||
"""Agent App runner: drive one conversation turn through the dify-agent backend.
|
||||
|
||||
Unlike the legacy ``AgentChatAppRunner`` (which runs an in-process ReAct loop),
|
||||
this runner delegates to the Agent backend: build the run request from the
|
||||
Agent Soul + conversation, create the run, consume its event stream, and
|
||||
republish the assistant answer as chat queue events so the existing
|
||||
EasyUI chat task pipeline persists the message and streams SSE. The conversation
|
||||
``session_snapshot`` is saved on success for multi-turn continuity (S3).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from clients.agent_backend import (
|
||||
AgentBackendError,
|
||||
AgentBackendInternalEventType,
|
||||
AgentBackendRunClient,
|
||||
AgentBackendRunEventAdapter,
|
||||
AgentBackendRunSucceededInternalEvent,
|
||||
AgentBackendStreamInternalEvent,
|
||||
)
|
||||
from core.app.apps.agent_app.runtime_request_builder import (
|
||||
AgentAppRuntimeBuildContext,
|
||||
AgentAppRuntimeRequestBuilder,
|
||||
)
|
||||
from core.app.apps.agent_app.session_store import AgentAppRuntimeSessionStore, AgentAppSessionScope
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import DifyRunContext
|
||||
from core.app.entities.queue_entities import QueueLLMChunkEvent, QueueMessageEndEvent
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
from models.agent_config_entities import AgentSoulConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def publish_text_answer(*, queue_manager: AppQueueManager, model_name: str, answer: str) -> None:
|
||||
"""Publish a complete assistant answer as one chunk + message-end.
|
||||
|
||||
The EasyUI chat task pipeline consumes a QueueLLMChunkEvent stream followed
|
||||
by a QueueMessageEndEvent; emitting the whole answer as a single chunk lets
|
||||
both the backend-produced answer and short-circuited answers (moderation /
|
||||
annotation reply) share the exact same persistence + SSE path.
|
||||
"""
|
||||
chunk = LLMResultChunk(
|
||||
model=model_name,
|
||||
prompt_messages=[],
|
||||
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=answer)),
|
||||
)
|
||||
queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER)
|
||||
queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=model_name,
|
||||
prompt_messages=[],
|
||||
message=AssistantPromptMessage(content=answer),
|
||||
usage=LLMUsage.empty_usage(),
|
||||
),
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
|
||||
class AgentAppRunner:
|
||||
"""Runs one Agent App conversation turn against the Agent backend."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
request_builder: AgentAppRuntimeRequestBuilder,
|
||||
agent_backend_client: AgentBackendRunClient,
|
||||
event_adapter: AgentBackendRunEventAdapter,
|
||||
session_store: AgentAppRuntimeSessionStore,
|
||||
) -> None:
|
||||
self._request_builder = request_builder
|
||||
self._agent_backend_client = agent_backend_client
|
||||
self._event_adapter = event_adapter
|
||||
self._session_store = session_store
|
||||
|
||||
def run(
|
||||
self,
|
||||
*,
|
||||
dify_context: DifyRunContext,
|
||||
agent_id: str,
|
||||
agent_config_snapshot_id: str,
|
||||
agent_soul: AgentSoulConfig,
|
||||
conversation_id: str,
|
||||
query: str,
|
||||
message_id: str,
|
||||
model_name: str,
|
||||
queue_manager: AppQueueManager,
|
||||
) -> None:
|
||||
scope = AgentAppSessionScope(
|
||||
tenant_id=dify_context.tenant_id,
|
||||
app_id=dify_context.app_id,
|
||||
conversation_id=conversation_id,
|
||||
agent_id=agent_id,
|
||||
agent_config_snapshot_id=agent_config_snapshot_id,
|
||||
)
|
||||
session_snapshot = self._session_store.load_active_snapshot(scope)
|
||||
|
||||
runtime = self._request_builder.build(
|
||||
AgentAppRuntimeBuildContext(
|
||||
dify_context=dify_context,
|
||||
agent_id=agent_id,
|
||||
agent_config_snapshot_id=agent_config_snapshot_id,
|
||||
agent_soul=agent_soul,
|
||||
conversation_id=conversation_id,
|
||||
user_query=query,
|
||||
idempotency_key=message_id,
|
||||
session_snapshot=session_snapshot,
|
||||
)
|
||||
)
|
||||
|
||||
create_response = self._agent_backend_client.create_run(runtime.request)
|
||||
terminal = self._consume_stream(create_response.run_id, queue_manager=queue_manager)
|
||||
|
||||
if not isinstance(terminal, AgentBackendRunSucceededInternalEvent):
|
||||
error = getattr(terminal, "error", None) or "Agent backend run did not complete successfully."
|
||||
raise AgentBackendError(str(error))
|
||||
|
||||
answer = self._extract_answer(terminal.output)
|
||||
self._publish_answer(queue_manager=queue_manager, model_name=model_name, answer=answer)
|
||||
self._save_session(scope=scope, backend_run_id=terminal.run_id, snapshot=terminal.session_snapshot)
|
||||
|
||||
def _consume_stream(self, run_id: str, *, queue_manager: AppQueueManager):
|
||||
terminal = None
|
||||
for public_event in self._agent_backend_client.stream_events(run_id):
|
||||
if queue_manager.is_stopped():
|
||||
self._cancel_run(run_id)
|
||||
raise GenerateTaskStoppedError()
|
||||
for internal_event in self._event_adapter.adapt(public_event):
|
||||
if queue_manager.is_stopped():
|
||||
self._cancel_run(run_id)
|
||||
raise GenerateTaskStoppedError()
|
||||
if internal_event.type in (
|
||||
AgentBackendInternalEventType.RUN_STARTED,
|
||||
AgentBackendInternalEventType.STREAM_EVENT,
|
||||
):
|
||||
# Stream deltas are accumulated by the backend into the
|
||||
# terminal output; token-level forwarding is an S3 refinement.
|
||||
if isinstance(internal_event, AgentBackendStreamInternalEvent):
|
||||
continue
|
||||
continue
|
||||
terminal = internal_event
|
||||
break
|
||||
if terminal is not None:
|
||||
break
|
||||
return terminal
|
||||
|
||||
def _cancel_run(self, run_id: str) -> None:
|
||||
try:
|
||||
self._agent_backend_client.cancel_run(run_id)
|
||||
except Exception:
|
||||
logger.warning("Failed to cancel stopped Agent App backend run: run_id=%s", run_id, exc_info=True)
|
||||
|
||||
def _publish_answer(self, *, queue_manager: AppQueueManager, model_name: str, answer: str) -> None:
|
||||
# MVP: emit the full answer as a single chunk + message-end. The chat
|
||||
# task pipeline streams the chunk over SSE and persists the message.
|
||||
publish_text_answer(queue_manager=queue_manager, model_name=model_name, answer=answer)
|
||||
|
||||
def _save_session(self, *, scope: AgentAppSessionScope, backend_run_id: str, snapshot: Any) -> None:
|
||||
try:
|
||||
self._session_store.save_active_snapshot(scope=scope, backend_run_id=backend_run_id, snapshot=snapshot)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist Agent App conversation session snapshot: "
|
||||
"tenant_id=%s app_id=%s conversation_id=%s agent_id=%s",
|
||||
scope.tenant_id,
|
||||
scope.app_id,
|
||||
scope.conversation_id,
|
||||
scope.agent_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_answer(output: JsonValue) -> str:
|
||||
"""Normalize the backend's terminal output to assistant text.
|
||||
|
||||
Free-text Agent Apps return a plain string; if a structured output is
|
||||
configured the value is a JSON object, which we serialize so the chat
|
||||
message always has a string body.
|
||||
"""
|
||||
if isinstance(output, str):
|
||||
return output
|
||||
if isinstance(output, dict):
|
||||
text = output.get("text")
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
return json.dumps(output, ensure_ascii=False)
|
||||
return json.dumps(output, ensure_ascii=False)
|
||||
|
||||
|
||||
__all__ = ["AgentAppRunner", "publish_text_answer"]
|
||||
15
api/core/app/apps/agent_app/generate_response_converter.py
Normal file
15
api/core/app/apps/agent_app/generate_response_converter.py
Normal file
@ -0,0 +1,15 @@
|
||||
"""Response converter for the Agent App type.
|
||||
|
||||
The Agent App streams the same chatbot response shape as the chat / agent-chat
|
||||
app types, so it reuses that converter wholesale; kept as a distinct subclass so
|
||||
the app type owns its converter and can diverge later.
|
||||
"""
|
||||
|
||||
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
|
||||
|
||||
|
||||
class AgentAppGenerateResponseConverter(AgentChatAppGenerateResponseConverter):
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ["AgentAppGenerateResponseConverter"]
|
||||
184
api/core/app/apps/agent_app/runtime_request_builder.py
Normal file
184
api/core/app/apps/agent_app/runtime_request_builder.py
Normal file
@ -0,0 +1,184 @@
|
||||
"""Build dify-agent run requests for one Agent App conversation turn.
|
||||
|
||||
Mirrors the workflow ``WorkflowAgentRuntimeRequestBuilder`` but for the Agent
|
||||
App surface: the user prompt is the chat message (no workflow-node job / no
|
||||
previous-node context), and multi-turn continuity flows through the
|
||||
conversation-keyed ``session_snapshot`` plus the history layer.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig
|
||||
from dify_agent.protocol import CreateRunRequest
|
||||
|
||||
from clients.agent_backend import (
|
||||
AgentBackendAgentAppRunInput,
|
||||
AgentBackendModelConfig,
|
||||
AgentBackendRunRequestBuilder,
|
||||
redact_for_agent_backend_log,
|
||||
)
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import DifyRunContext
|
||||
from core.workflow.nodes.agent_v2.plugin_tools_builder import (
|
||||
WorkflowAgentPluginToolsBuilder,
|
||||
WorkflowAgentPluginToolsBuildError,
|
||||
)
|
||||
from core.workflow.nodes.agent_v2.runtime_request_builder import build_shell_layer_config
|
||||
from models.agent_config_entities import AgentSoulConfig
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
|
||||
class AgentAppRuntimeRequestBuildError(ValueError):
|
||||
"""Raised when Agent App state cannot be mapped to a valid run request."""
|
||||
|
||||
def __init__(self, error_code: str, message: str) -> None:
|
||||
self.error_code = error_code
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class CredentialsProvider(Protocol):
|
||||
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AgentAppRuntimeBuildContext:
|
||||
dify_context: DifyRunContext
|
||||
agent_id: str
|
||||
agent_config_snapshot_id: str
|
||||
agent_soul: AgentSoulConfig
|
||||
conversation_id: str
|
||||
user_query: str
|
||||
idempotency_key: str
|
||||
session_snapshot: CompositorSessionSnapshot | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AgentAppRuntimeRequest:
|
||||
request: CreateRunRequest
|
||||
redacted_request: dict[str, Any]
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
class AgentAppRuntimeRequestBuilder:
|
||||
"""Build dify-agent run requests from Agent App conversation state."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
credentials_provider: CredentialsProvider,
|
||||
request_builder: AgentBackendRunRequestBuilder | None = None,
|
||||
plugin_tools_builder: WorkflowAgentPluginToolsBuilder | None = None,
|
||||
) -> None:
|
||||
self._credentials_provider = credentials_provider
|
||||
self._request_builder = request_builder or AgentBackendRunRequestBuilder()
|
||||
self._plugin_tools_builder = plugin_tools_builder or WorkflowAgentPluginToolsBuilder()
|
||||
|
||||
def build(self, context: AgentAppRuntimeBuildContext) -> AgentAppRuntimeRequest:
|
||||
agent_soul = context.agent_soul
|
||||
if agent_soul.model is None:
|
||||
raise AgentAppRuntimeRequestBuildError(
|
||||
"agent_model_not_configured",
|
||||
"Agent App requires the Agent Soul model to be configured.",
|
||||
)
|
||||
|
||||
metadata = self._build_metadata(context)
|
||||
credentials = self._credentials_provider.fetch(agent_soul.model.model_provider, agent_soul.model.model)
|
||||
try:
|
||||
tools_layer = self._plugin_tools_builder.build(
|
||||
tenant_id=context.dify_context.tenant_id,
|
||||
app_id=context.dify_context.app_id,
|
||||
user_id=context.dify_context.user_id,
|
||||
tools=agent_soul.tools,
|
||||
invoke_from=context.dify_context.invoke_from,
|
||||
)
|
||||
except WorkflowAgentPluginToolsBuildError as error:
|
||||
raise AgentAppRuntimeRequestBuildError(error.error_code, str(error)) from error
|
||||
if tools_layer is not None or agent_soul.tools.cli_tools:
|
||||
metadata["agent_tools"] = {
|
||||
"dify_tool_count": len(tools_layer.tools) if tools_layer is not None else 0,
|
||||
"dify_tool_names": [tool.name or tool.tool_name for tool in tools_layer.tools]
|
||||
if tools_layer is not None
|
||||
else [],
|
||||
"cli_tool_count": len(agent_soul.tools.cli_tools),
|
||||
}
|
||||
|
||||
request = self._request_builder.build_for_agent_app(
|
||||
AgentBackendAgentAppRunInput(
|
||||
model=AgentBackendModelConfig(
|
||||
plugin_id=self._plugin_daemon_plugin_id(
|
||||
plugin_id=agent_soul.model.plugin_id,
|
||||
model_provider=agent_soul.model.model_provider,
|
||||
),
|
||||
model_provider=self._plugin_daemon_provider_name(agent_soul.model.model_provider),
|
||||
model=agent_soul.model.model,
|
||||
credentials=self._normalize_credentials(credentials),
|
||||
model_settings=agent_soul.model.model_settings.model_dump(mode="json", exclude_none=True),
|
||||
),
|
||||
execution_context=DifyExecutionContextLayerConfig(
|
||||
tenant_id=context.dify_context.tenant_id,
|
||||
user_id=context.dify_context.user_id,
|
||||
app_id=context.dify_context.app_id,
|
||||
conversation_id=context.conversation_id,
|
||||
agent_id=context.agent_id,
|
||||
agent_config_version_id=context.agent_config_snapshot_id,
|
||||
invoke_from="agent_app",
|
||||
),
|
||||
agent_soul_prompt=agent_soul.prompt.system_prompt or None,
|
||||
user_prompt=context.user_query,
|
||||
tools=tools_layer,
|
||||
include_shell=dify_config.AGENT_SHELL_ENABLED,
|
||||
shell_config=build_shell_layer_config(agent_soul),
|
||||
session_snapshot=context.session_snapshot,
|
||||
idempotency_key=context.idempotency_key,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
redacted = cast(dict[str, Any], redact_for_agent_backend_log(request))
|
||||
return AgentAppRuntimeRequest(request=request, redacted_request=redacted, metadata=metadata)
|
||||
|
||||
@staticmethod
|
||||
def _build_metadata(context: AgentAppRuntimeBuildContext) -> dict[str, Any]:
|
||||
return {
|
||||
"tenant_id": context.dify_context.tenant_id,
|
||||
"app_id": context.dify_context.app_id,
|
||||
"conversation_id": context.conversation_id,
|
||||
"agent_id": context.agent_id,
|
||||
"agent_config_snapshot_id": context.agent_config_snapshot_id,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _plugin_daemon_plugin_id(*, plugin_id: str, model_provider: str) -> str:
|
||||
"""Return the transport plugin id expected by plugin-daemon headers."""
|
||||
if plugin_id.count("/") == 1:
|
||||
return plugin_id
|
||||
if plugin_id:
|
||||
return ModelProviderID(plugin_id).plugin_id
|
||||
return ModelProviderID(model_provider).plugin_id
|
||||
|
||||
@staticmethod
|
||||
def _plugin_daemon_provider_name(model_provider: str) -> str:
|
||||
"""Return the provider name expected by plugin-daemon dispatch payloads."""
|
||||
return ModelProviderID(model_provider).provider_name
|
||||
|
||||
@staticmethod
|
||||
def _normalize_credentials(credentials: Mapping[str, Any]) -> dict[str, str | int | float | bool | None]:
|
||||
normalized: dict[str, str | int | float | bool | None] = {}
|
||||
for key, value in credentials.items():
|
||||
if isinstance(value, str | int | float | bool) or value is None:
|
||||
normalized[key] = value
|
||||
else:
|
||||
normalized[key] = str(value)
|
||||
return normalized
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AgentAppRuntimeBuildContext",
|
||||
"AgentAppRuntimeRequest",
|
||||
"AgentAppRuntimeRequestBuildError",
|
||||
"AgentAppRuntimeRequestBuilder",
|
||||
]
|
||||
146
api/core/app/apps/agent_app/session_store.py
Normal file
146
api/core/app/apps/agent_app/session_store.py
Normal file
@ -0,0 +1,146 @@
|
||||
"""Conversation-keyed Agent backend session store for the Agent App type.
|
||||
|
||||
Shares the unified ``agent_runtime_sessions`` table with the workflow Agent
|
||||
Node store, but owns rows with ``owner_type = conversation``: one Agent App
|
||||
conversation maps to one Agent session, so multi-turn chat re-enters the same
|
||||
``session_snapshot``. Cross-conversation memory (PRD Global / Per app) is a
|
||||
phase-2 concern and not modeled here.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.agent import (
|
||||
AgentRuntimeSession,
|
||||
AgentRuntimeSessionOwnerType,
|
||||
AgentRuntimeSessionStatus,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AgentAppSessionScope:
|
||||
"""Identity of one Agent App conversation session."""
|
||||
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
conversation_id: str
|
||||
agent_id: str
|
||||
agent_config_snapshot_id: str
|
||||
|
||||
|
||||
class AgentAppRuntimeSessionStore:
|
||||
"""Persists Agent backend session snapshots for Agent App conversations."""
|
||||
|
||||
def load_active_snapshot(self, scope: AgentAppSessionScope) -> CompositorSessionSnapshot | None:
|
||||
with session_factory.create_session() as session:
|
||||
row = session.scalar(self._active_stmt(scope))
|
||||
if row is None:
|
||||
return None
|
||||
return CompositorSessionSnapshot.model_validate_json(row.session_snapshot)
|
||||
|
||||
def load_active_snapshot_for_conversation(
|
||||
self, *, tenant_id: str, app_id: str, conversation_id: str
|
||||
) -> CompositorSessionSnapshot | None:
|
||||
"""Load a conversation's active snapshot without the agent/config scope.
|
||||
|
||||
One Agent App conversation maps to one active session, so the workspace
|
||||
inspector can resolve it from the conversation alone (it does not know
|
||||
which agent config version a past turn ran under).
|
||||
"""
|
||||
stmt = (
|
||||
select(AgentRuntimeSession)
|
||||
.where(
|
||||
AgentRuntimeSession.owner_type == AgentRuntimeSessionOwnerType.CONVERSATION,
|
||||
AgentRuntimeSession.tenant_id == tenant_id,
|
||||
AgentRuntimeSession.app_id == app_id,
|
||||
AgentRuntimeSession.conversation_id == conversation_id,
|
||||
AgentRuntimeSession.status == AgentRuntimeSessionStatus.ACTIVE,
|
||||
)
|
||||
.order_by(AgentRuntimeSession.updated_at.desc())
|
||||
)
|
||||
with session_factory.create_session() as session:
|
||||
row = session.scalar(stmt)
|
||||
if row is None:
|
||||
return None
|
||||
return CompositorSessionSnapshot.model_validate_json(row.session_snapshot)
|
||||
|
||||
def save_active_snapshot(
|
||||
self,
|
||||
*,
|
||||
scope: AgentAppSessionScope,
|
||||
backend_run_id: str,
|
||||
snapshot: CompositorSessionSnapshot | None,
|
||||
) -> None:
|
||||
if snapshot is None:
|
||||
return
|
||||
snapshot_json = snapshot.model_dump_json()
|
||||
with session_factory.create_session() as session:
|
||||
row = session.scalar(self._scope_stmt(scope))
|
||||
if row is None:
|
||||
row = AgentRuntimeSession(
|
||||
tenant_id=scope.tenant_id,
|
||||
app_id=scope.app_id,
|
||||
owner_type=AgentRuntimeSessionOwnerType.CONVERSATION,
|
||||
agent_id=scope.agent_id,
|
||||
agent_config_snapshot_id=scope.agent_config_snapshot_id,
|
||||
conversation_id=scope.conversation_id,
|
||||
backend_run_id=backend_run_id,
|
||||
session_snapshot=snapshot_json,
|
||||
composition_layer_specs="[]",
|
||||
status=AgentRuntimeSessionStatus.ACTIVE,
|
||||
)
|
||||
session.add(row)
|
||||
else:
|
||||
row.backend_run_id = backend_run_id
|
||||
row.session_snapshot = snapshot_json
|
||||
row.status = AgentRuntimeSessionStatus.ACTIVE
|
||||
row.cleaned_at = None
|
||||
session.flush()
|
||||
other_rows = session.scalars(
|
||||
select(AgentRuntimeSession).where(
|
||||
AgentRuntimeSession.owner_type == AgentRuntimeSessionOwnerType.CONVERSATION,
|
||||
AgentRuntimeSession.tenant_id == scope.tenant_id,
|
||||
AgentRuntimeSession.app_id == scope.app_id,
|
||||
AgentRuntimeSession.conversation_id == scope.conversation_id,
|
||||
AgentRuntimeSession.status == AgentRuntimeSessionStatus.ACTIVE,
|
||||
AgentRuntimeSession.id != row.id,
|
||||
)
|
||||
).all()
|
||||
for other_row in other_rows:
|
||||
other_row.status = AgentRuntimeSessionStatus.CLEANED
|
||||
other_row.cleaned_at = naive_utc_now()
|
||||
session.commit()
|
||||
|
||||
def mark_cleaned(self, *, scope: AgentAppSessionScope, backend_run_id: str | None = None) -> None:
|
||||
with session_factory.create_session() as session:
|
||||
row = session.scalar(self._active_stmt(scope))
|
||||
if row is None:
|
||||
return
|
||||
if backend_run_id is not None:
|
||||
row.backend_run_id = backend_run_id
|
||||
row.status = AgentRuntimeSessionStatus.CLEANED
|
||||
row.cleaned_at = naive_utc_now()
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def _scope_stmt(scope: AgentAppSessionScope):
|
||||
return select(AgentRuntimeSession).where(
|
||||
AgentRuntimeSession.owner_type == AgentRuntimeSessionOwnerType.CONVERSATION,
|
||||
AgentRuntimeSession.tenant_id == scope.tenant_id,
|
||||
AgentRuntimeSession.conversation_id == scope.conversation_id,
|
||||
AgentRuntimeSession.agent_id == scope.agent_id,
|
||||
AgentRuntimeSession.agent_config_snapshot_id == scope.agent_config_snapshot_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _active_stmt(cls, scope: AgentAppSessionScope):
|
||||
return cls._scope_stmt(scope).where(AgentRuntimeSession.status == AgentRuntimeSessionStatus.ACTIVE)
|
||||
|
||||
|
||||
__all__ = ["AgentAppRuntimeSessionStore", "AgentAppSessionScope"]
|
||||
@ -20,6 +20,7 @@ from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
|
||||
from core.helper.trace_id_helper import extract_trace_session_id_from_args
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
@ -96,7 +97,10 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
query = query.replace("\x00", "")
|
||||
inputs = args["inputs"]
|
||||
|
||||
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
|
||||
extras = {
|
||||
"auto_generate_conversation_name": args.get("auto_generate_name", True),
|
||||
**extract_trace_session_id_from_args(args),
|
||||
}
|
||||
|
||||
# get conversation
|
||||
conversation = None
|
||||
|
||||
@ -134,6 +134,10 @@ class AppQueueManager(ABC):
|
||||
self._check_for_sqlalchemy_models(event.model_dump())
|
||||
self._publish(event, pub_from)
|
||||
|
||||
def is_stopped(self) -> bool:
|
||||
"""Return whether the current task has been manually stopped."""
|
||||
return self._is_stopped()
|
||||
|
||||
@abstractmethod
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
|
||||
@ -20,6 +20,7 @@ from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
|
||||
from core.helper.trace_id_helper import extract_trace_session_id_from_args
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
@ -89,7 +90,10 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
query = query.replace("\x00", "")
|
||||
inputs = args["inputs"]
|
||||
|
||||
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
|
||||
extras = {
|
||||
"auto_generate_conversation_name": args.get("auto_generate_name", True),
|
||||
**extract_trace_session_id_from_args(args),
|
||||
}
|
||||
|
||||
# get conversation
|
||||
conversation = None
|
||||
|
||||
@ -52,15 +52,11 @@ from core.tools.tool_manager import ToolManager
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.workflow.human_input_forms import load_form_tokens_by_form_id
|
||||
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
|
||||
|
||||
# Maps the entry surface a workflow was invoked from to the HITL surface that
|
||||
# its resume tokens must be filtered for. Surfaces not in this map fall back to
|
||||
# the general priority ordering (typically CONSOLE > BACKSTAGE).
|
||||
_INVOKE_FROM_TO_HITL_SURFACE: Mapping[InvokeFrom, HumanInputSurface] = {
|
||||
InvokeFrom.SERVICE_API: HumanInputSurface.SERVICE_API,
|
||||
InvokeFrom.OPENAPI: HumanInputSurface.OPENAPI,
|
||||
}
|
||||
from core.workflow.human_input_policy import (
|
||||
HumanInputSurface,
|
||||
enrich_human_input_pause_reasons,
|
||||
resolve_human_input_pause_reason_inputs,
|
||||
)
|
||||
from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
@ -83,6 +79,14 @@ from models.human_input import HumanInputForm
|
||||
from models.workflow import WorkflowRun
|
||||
from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator
|
||||
|
||||
# Maps the entry surface a workflow was invoked from to the HITL surface that
|
||||
# its resume tokens must be filtered for. Surfaces not in this map fall back to
|
||||
# the general priority ordering (typically CONSOLE > BACKSTAGE).
|
||||
_INVOKE_FROM_TO_HITL_SURFACE: Mapping[InvokeFrom, HumanInputSurface] = {
|
||||
InvokeFrom.SERVICE_API: HumanInputSurface.SERVICE_API,
|
||||
InvokeFrom.OPENAPI: HumanInputSurface.OPENAPI,
|
||||
}
|
||||
|
||||
NodeExecutionId = NewType("NodeExecutionId", str)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -327,8 +331,13 @@ class WorkflowResponseConverter:
|
||||
encoded_outputs = self._encode_outputs(event.outputs) or {}
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API:
|
||||
encoded_outputs = {}
|
||||
pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons]
|
||||
human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)]
|
||||
variable_pool = graph_runtime_state.variable_pool
|
||||
resolved_reasons = resolve_human_input_pause_reason_inputs(
|
||||
event.reasons,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
pause_reasons = [reason.model_dump(mode="json") for reason in resolved_reasons]
|
||||
human_input_form_ids = [reason.form_id for reason in resolved_reasons if isinstance(reason, HumanInputRequired)]
|
||||
expiration_times_by_form_id: dict[str, datetime] = {}
|
||||
display_in_ui_by_form_id: dict[str, bool] = {}
|
||||
form_token_by_form_id: dict[str, str] = {}
|
||||
@ -365,7 +374,7 @@ class WorkflowResponseConverter:
|
||||
|
||||
responses: list[StreamResponse] = []
|
||||
|
||||
for reason in event.reasons:
|
||||
for reason in resolved_reasons:
|
||||
if isinstance(reason, HumanInputRequired):
|
||||
expiration_time = expiration_times_by_form_id.get(reason.form_id)
|
||||
if expiration_time is None:
|
||||
@ -413,17 +422,19 @@ class WorkflowResponseConverter:
|
||||
self, *, event: QueueHumanInputFormFilledEvent, task_id: str
|
||||
) -> HumanInputFormFilledResponse:
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
return HumanInputFormFilledResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=run_id,
|
||||
data=HumanInputFormFilledResponse.Data(
|
||||
node_id=event.node_id,
|
||||
node_title=event.node_title,
|
||||
rendered_content=event.rendered_content,
|
||||
action_id=event.action_id,
|
||||
action_text=event.action_text,
|
||||
),
|
||||
data = HumanInputFormFilledResponse.Data(
|
||||
node_id=event.node_id,
|
||||
node_title=event.node_title,
|
||||
rendered_content=event.rendered_content,
|
||||
action_id=event.action_id,
|
||||
action_text=event.action_text,
|
||||
)
|
||||
if event.submitted_data is not None:
|
||||
runtime_type_converter = WorkflowRuntimeTypeConverter()
|
||||
|
||||
data.submitted_data = runtime_type_converter.value_to_json_encodable_recursive(event.submitted_data)
|
||||
|
||||
return HumanInputFormFilledResponse(task_id=task_id, workflow_run_id=run_id, data=data)
|
||||
|
||||
def human_input_form_timeout_to_stream_response(
|
||||
self, *, event: QueueHumanInputFormTimeoutEvent, task_id: str
|
||||
|
||||
@ -20,6 +20,7 @@ from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
|
||||
from core.helper.trace_id_helper import extract_trace_session_id_from_args
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
@ -148,7 +149,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
extras={},
|
||||
extras={
|
||||
**extract_trace_session_id_from_args(args),
|
||||
},
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Literal, overload
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
@ -32,7 +32,11 @@ from core.app.entities.task_entities import (
|
||||
)
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args, extract_parent_trace_context_from_args
|
||||
from core.helper.trace_id_helper import (
|
||||
extract_external_trace_id_from_args,
|
||||
extract_parent_trace_context_from_args,
|
||||
extract_trace_session_id_from_args,
|
||||
)
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
|
||||
@ -57,26 +61,13 @@ SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _extract_trace_session_id_from_debug_args(args: Mapping[str, Any] | Any) -> dict[str, str]:
|
||||
if isinstance(args, Mapping):
|
||||
return extract_trace_session_id_from_args(args)
|
||||
return extract_trace_session_id_from_args({"trace_session_id": getattr(args, "trace_session_id", None)})
|
||||
|
||||
|
||||
class WorkflowAppGenerator(BaseAppGenerator):
|
||||
@staticmethod
|
||||
def _ensure_snippet_start_node_in_worker(*, session: Session, workflow: Workflow) -> Workflow:
|
||||
"""Re-apply snippet virtual Start injection after worker reloads workflow from DB."""
|
||||
if workflow.kind_or_standard != "snippet":
|
||||
return workflow
|
||||
|
||||
from models.snippet import CustomizedSnippet
|
||||
from services.snippet_generate_service import SnippetGenerateService
|
||||
|
||||
snippet = session.scalar(
|
||||
select(CustomizedSnippet).where(
|
||||
CustomizedSnippet.id == workflow.app_id,
|
||||
CustomizedSnippet.tenant_id == workflow.tenant_id,
|
||||
)
|
||||
)
|
||||
if snippet is None:
|
||||
return workflow
|
||||
return SnippetGenerateService.ensure_start_node_for_worker(workflow, snippet)
|
||||
|
||||
@staticmethod
|
||||
def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool:
|
||||
return not bool(args.get(SKIP_PREPARE_USER_INPUTS_KEY))
|
||||
@ -186,6 +177,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
extras = {
|
||||
**extract_external_trace_id_from_args(args),
|
||||
**extract_parent_trace_context_from_args(args),
|
||||
**extract_trace_session_id_from_args(args),
|
||||
}
|
||||
workflow_run_id = str(workflow_run_id or uuid.uuid4())
|
||||
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
|
||||
@ -429,7 +421,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
extras={
|
||||
"auto_generate_conversation_name": False,
|
||||
**_extract_trace_session_id_from_debug_args(args),
|
||||
},
|
||||
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
|
||||
node_id=node_id, inputs=args["inputs"]
|
||||
),
|
||||
@ -515,7 +510,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
extras={
|
||||
"auto_generate_conversation_name": False,
|
||||
**_extract_trace_session_id_from_debug_args(args),
|
||||
},
|
||||
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args.inputs or {}),
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
@ -594,8 +592,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
if workflow is None:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
workflow = self._ensure_snippet_start_node_in_worker(session=session, workflow=workflow)
|
||||
|
||||
# Determine system_user_id based on invocation source
|
||||
is_external_api_call = application_generate_entity.invoke_from in {
|
||||
InvokeFrom.WEB_APP,
|
||||
|
||||
@ -87,6 +87,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
root_node_id=self._root_node_id,
|
||||
trace_session_id=self.application_generate_entity.extras.get("trace_session_id"),
|
||||
)
|
||||
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
@ -94,6 +95,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
single_loop_run=self.application_generate_entity.single_loop_run,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
trace_session_id=self.application_generate_entity.extras.get("trace_session_id"),
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
@ -128,6 +130,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
root_node_id=root_node_id,
|
||||
trace_session_id=self.application_generate_entity.extras.get("trace_session_id"),
|
||||
)
|
||||
|
||||
# RUN WORKFLOW
|
||||
|
||||
@ -118,6 +118,7 @@ class WorkflowBasedAppRunner:
|
||||
tenant_id: str = "",
|
||||
user_id: str = "",
|
||||
root_node_id: str | None = None,
|
||||
trace_session_id: str | None = None,
|
||||
) -> Graph:
|
||||
"""
|
||||
Init graph
|
||||
@ -138,6 +139,7 @@ class WorkflowBasedAppRunner:
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
trace_session_id=trace_session_id,
|
||||
)
|
||||
graph_init_context = DifyGraphInitContext(
|
||||
workflow_id=workflow_id,
|
||||
@ -171,6 +173,7 @@ class WorkflowBasedAppRunner:
|
||||
single_loop_run: Any | None = None,
|
||||
*,
|
||||
user_id: str,
|
||||
trace_session_id: str | None = None,
|
||||
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
|
||||
"""
|
||||
Prepare graph, variable pool, and runtime state for single node execution
|
||||
@ -208,6 +211,7 @@ class WorkflowBasedAppRunner:
|
||||
node_type_filter_key="iteration_id",
|
||||
node_type_label="iteration",
|
||||
user_id=user_id,
|
||||
trace_session_id=trace_session_id,
|
||||
)
|
||||
elif single_loop_run:
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run(
|
||||
@ -218,6 +222,7 @@ class WorkflowBasedAppRunner:
|
||||
node_type_filter_key="loop_id",
|
||||
node_type_label="loop",
|
||||
user_id=user_id,
|
||||
trace_session_id=trace_session_id,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Neither single_iteration_run nor single_loop_run is specified")
|
||||
@ -236,6 +241,7 @@ class WorkflowBasedAppRunner:
|
||||
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
|
||||
*,
|
||||
user_id: str = "",
|
||||
trace_session_id: str | None = None,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get graph and variable pool for single node execution (iteration or loop).
|
||||
@ -301,6 +307,7 @@ class WorkflowBasedAppRunner:
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
trace_session_id=trace_session_id,
|
||||
)
|
||||
graph_init_context = DifyGraphInitContext(
|
||||
workflow_id=workflow.id,
|
||||
@ -435,6 +442,7 @@ class WorkflowBasedAppRunner:
|
||||
rendered_content=event.rendered_content,
|
||||
action_id=event.action_id,
|
||||
action_text=event.action_text,
|
||||
submitted_data=event.submitted_data,
|
||||
)
|
||||
)
|
||||
case NodeRunHumanInputFormTimeoutEvent():
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user