mirror of
https://github.com/langgenius/dify.git
synced 2026-06-04 07:56:20 +08:00
Compare commits
132 Commits
build/cli
...
feat/help-
| Author | SHA1 | Date | |
|---|---|---|---|
| 77a6bc28a7 | |||
| 5eb26785ba | |||
| 1f2492d635 | |||
| 23def1ef3f | |||
| 831a53e088 | |||
| 60295b66c3 | |||
| 9727f7a94a | |||
| 2d44910bf8 | |||
| 99174f90c8 | |||
| db3af89a4c | |||
| a3f3a8d169 | |||
| 1c0a3b3351 | |||
| 50b67efacb | |||
| 919f56ebee | |||
| 10d9a6bc49 | |||
| 06d886d4ba | |||
| b77a242b13 | |||
| 9f01fbf770 | |||
| 5c7f05bd10 | |||
| 02e1a60cde | |||
| 57b573d02b | |||
| 9de40e8f21 | |||
| cad0942f4d | |||
| cb9b1b593e | |||
| 2a8bdc2373 | |||
| ee6a07d13c | |||
| 2d6c9300e3 | |||
| d6b4c800c2 | |||
| 1b37635f92 | |||
| 86af36429d | |||
| b96ea94505 | |||
| d649cccda0 | |||
| 5cbbd78f38 | |||
| 5a0ad4ecd9 | |||
| 1e76b9e1b8 | |||
| 1b972c4e09 | |||
| 7968d2c3c8 | |||
| 7507e9ba67 | |||
| ca31762e26 | |||
| f591da7865 | |||
| f19679b217 | |||
| b682591c7a | |||
| 8f6b59feff | |||
| 99833f65d8 | |||
| 696fc5c213 | |||
| eae44cfecb | |||
| dea4e66456 | |||
| 3cd0da303a | |||
| 888483a2f8 | |||
| 7056985f72 | |||
| 6ce61eae59 | |||
| 079af312c6 | |||
| 0da13dfe4d | |||
| 1ff4d75084 | |||
| e35d23c3cb | |||
| e530e84772 | |||
| 2257a4f1ef | |||
| f465dc5090 | |||
| 5c1cfe6ada | |||
| 8d401d84c7 | |||
| b74287c2ab | |||
| c64d3e98c4 | |||
| a3265f722e | |||
| 5658065b97 | |||
| 8fc2807194 | |||
| fc7716704d | |||
| 71ffaacb58 | |||
| cfc1cf2b8c | |||
| 055d9b9f0a | |||
| 21711bebeb | |||
| becccbf288 | |||
| 86497045c9 | |||
| 687a177b24 | |||
| 4a6d278354 | |||
| 7d69302e9f | |||
| bcd573e560 | |||
| 07c0c4e7b1 | |||
| a8a2ca7b98 | |||
| de47d43b65 | |||
| 240912cef5 | |||
| 72e040ead3 | |||
| c0ee821d45 | |||
| c7c3296572 | |||
| e7be04fd58 | |||
| df6b5be50a | |||
| 8e5f09091b | |||
| 0a3005701f | |||
| d8571ce965 | |||
| f241ae25be | |||
| c6474a2a8b | |||
| 480d05bc48 | |||
| f75725ccd9 | |||
| 2fe8c48255 | |||
| ec5404cc9d | |||
| 20f62b9919 | |||
| 04f5555580 | |||
| 129af96c23 | |||
| df40960f5d | |||
| 599960024d | |||
| 6805d9bfc0 | |||
| 928f888ef5 | |||
| f46c03460e | |||
| 0b60338ad5 | |||
| 91ac465982 | |||
| 9490d63c50 | |||
| ae538ced47 | |||
| 487249728b | |||
| 372a2e3e9c | |||
| 4939a9c33d | |||
| b6f92f1dc4 | |||
| ce276573a8 | |||
| 5070cc9668 | |||
| a392a72960 | |||
| 30270b5c30 | |||
| 24715a9570 | |||
| c530a5d272 | |||
| 418ee7398e | |||
| 78f40c0d25 | |||
| 2cc567c6a3 | |||
| a180ab19e4 | |||
| 13eaa436e7 | |||
| 3596d12e4c | |||
| e8de10a3b5 | |||
| f5ab5e7eb3 | |||
| 0c40e1c2a0 | |||
| c29d76757e | |||
| 91c1d3ad81 | |||
| 57b02e341c | |||
| b94ff65e9f | |||
| 678260e34e | |||
| 739e34d08a | |||
| 825fb9cb89 |
@ -1 +0,0 @@
|
||||
../../.agents/skills/frontend-query-mutation
|
||||
1
.claude/skills/how-to-write-component
Symbolic link
1
.claude/skills/how-to-write-component
Symbolic link
@ -0,0 +1 @@
|
||||
../../.agents/skills/how-to-write-component
|
||||
1
.github/CODEOWNERS
vendored
1
.github/CODEOWNERS
vendored
@ -166,6 +166,7 @@
|
||||
|
||||
# Frontend - App - API Documentation
|
||||
/web/app/components/develop/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/develop/template/*.mdx @JzoNgKVO @iamjoel @RiskeyL
|
||||
|
||||
# Frontend - App - Logs and Annotations
|
||||
/web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
|
||||
|
||||
20
.github/workflows/autofix.yml
vendored
20
.github/workflows/autofix.yml
vendored
@ -51,6 +51,15 @@ jobs:
|
||||
with:
|
||||
files: |
|
||||
api/**
|
||||
- name: Check dify-agent inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
id: dify-agent-changes
|
||||
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
|
||||
with:
|
||||
files: |
|
||||
dify-agent/**/*.py
|
||||
dify-agent/pyproject.toml
|
||||
dify-agent/uv.lock
|
||||
- if: github.event_name != 'merge_group'
|
||||
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
|
||||
with:
|
||||
@ -76,6 +85,17 @@ jobs:
|
||||
# Format code
|
||||
uv run ruff format ..
|
||||
|
||||
- if: github.event_name != 'merge_group' && steps.dify-agent-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
cd dify-agent
|
||||
uv sync --dev
|
||||
# fmt first to avoid line too long
|
||||
uv run ruff format .
|
||||
# Fix lint errors
|
||||
uv run ruff check --fix .
|
||||
# Format code
|
||||
uv run ruff format .
|
||||
|
||||
- name: count migration progress
|
||||
if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
name: Deploy Agent Dev
|
||||
name: Deploy SaaS
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@ -7,7 +7,7 @@ on:
|
||||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "deploy/agent-dev"
|
||||
- "deploy/saas"
|
||||
types:
|
||||
- completed
|
||||
|
||||
@ -16,13 +16,13 @@ jobs:
|
||||
runs-on: depot-ubuntu-24.04
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/agent-dev'
|
||||
github.event.workflow_run.head_branch == 'deploy/saas'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
|
||||
with:
|
||||
host: ${{ secrets.AGENT_DEV_SSH_HOST }}
|
||||
host: ${{ secrets.SAAS_DEV_SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
script: |
|
||||
${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}
|
||||
${{ vars.SSH_SCRIPT_SAAS_DEV || secrets.SSH_SCRIPT_SAAS_DEV }}
|
||||
63
.github/workflows/style.yml
vendored
63
.github/workflows/style.yml
vendored
@ -95,6 +95,51 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Web tsslint
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
env:
|
||||
NODE_OPTIONS: --max-old-space-size=4096
|
||||
run: vp run lint:tss
|
||||
|
||||
- name: Web dead code check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: vp run knip
|
||||
|
||||
ts-common-style:
|
||||
name: TS Common
|
||||
runs-on: depot-ubuntu-24.04
|
||||
permissions:
|
||||
checks: write
|
||||
pull-requests: read
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
cli/**
|
||||
e2e/**
|
||||
sdks/nodejs-client/**
|
||||
packages/**
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.nvmrc
|
||||
eslint.config.mjs
|
||||
.github/workflows/style.yml
|
||||
.github/actions/setup-web/**
|
||||
|
||||
- name: Setup web environment
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Restore ESLint cache
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
id: eslint-cache-restore
|
||||
@ -105,28 +150,14 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
|
||||
|
||||
- name: Web style check
|
||||
- name: Style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: .
|
||||
run: vp run lint:ci
|
||||
|
||||
- name: Web tsslint
|
||||
- name: Type check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
env:
|
||||
NODE_OPTIONS: --max-old-space-size=4096
|
||||
run: vp run lint:tss
|
||||
|
||||
- name: Web type check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: .
|
||||
run: vp run type-check
|
||||
|
||||
- name: Web dead code check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: vp run knip
|
||||
|
||||
- name: Save ESLint cache
|
||||
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
|
||||
uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -259,3 +259,6 @@ scripts/stress-test/reports/
|
||||
.qoder/*
|
||||
.context/
|
||||
.eslintcache
|
||||
|
||||
# Vitest local reports
|
||||
web/.vitest-reports/
|
||||
|
||||
27
SECURITY.md
Normal file
27
SECURITY.md
Normal file
@ -0,0 +1,27 @@
|
||||
# Security Policy
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
If you believe you have found a security vulnerability in Dify, please report it privately through GitHub Security Advisories:
|
||||
|
||||
https://github.com/langgenius/dify/security/advisories/new
|
||||
|
||||
Please do not report security vulnerabilities through public GitHub issues, discussions, or pull requests.
|
||||
|
||||
When submitting a report, include as much relevant information as you can safely provide, such as:
|
||||
|
||||
- A description of the vulnerability
|
||||
- Steps to reproduce, if safe to share privately
|
||||
- Affected components, versions, or configurations
|
||||
- Potential impact
|
||||
- Any suggested mitigation or fix, if available
|
||||
|
||||
The maintainers will review reports submitted through GitHub Security Advisories and coordinate follow-up there.
|
||||
|
||||
## Public Disclosure
|
||||
|
||||
Please avoid publicly disclosing details of a vulnerability until it has been reviewed and, where appropriate, a fix or mitigation has been made available.
|
||||
|
||||
## Security Updates
|
||||
|
||||
Security fixes may be released through normal project releases or other appropriate channels. Users are encouraged to keep Dify deployments up to date.
|
||||
@ -27,7 +27,7 @@ COPY api/providers ./providers
|
||||
COPY dify-agent/pyproject.toml dify-agent/README.md /app/dify-agent/
|
||||
COPY dify-agent/src /app/dify-agent/src
|
||||
# Trust the checked-in lock during image builds; local path sources are copied from the repository context.
|
||||
RUN uv sync --frozen --no-dev
|
||||
RUN uv sync --frozen --no-dev --no-editable
|
||||
|
||||
# production stage
|
||||
FROM base AS production
|
||||
|
||||
@ -34,6 +34,7 @@ from clients.agent_backend.request_builder import (
|
||||
DIFY_PLUGIN_TOOLS_LAYER_ID,
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID,
|
||||
AgentBackendAgentAppRunInput,
|
||||
AgentBackendModelConfig,
|
||||
AgentBackendOutputConfig,
|
||||
AgentBackendRunRequestBuilder,
|
||||
@ -49,6 +50,7 @@ __all__ = [
|
||||
"DIFY_PLUGIN_TOOLS_LAYER_ID",
|
||||
"WORKFLOW_NODE_JOB_PROMPT_LAYER_ID",
|
||||
"WORKFLOW_USER_PROMPT_LAYER_ID",
|
||||
"AgentBackendAgentAppRunInput",
|
||||
"AgentBackendError",
|
||||
"AgentBackendHTTPError",
|
||||
"AgentBackendInternalEvent",
|
||||
|
||||
@ -45,6 +45,7 @@ from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator
|
||||
AGENT_SOUL_PROMPT_LAYER_ID = "agent_soul_prompt"
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID = "workflow_node_job_prompt"
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID = "workflow_user_prompt"
|
||||
AGENT_APP_USER_PROMPT_LAYER_ID = "agent_app_user_prompt"
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_ID = "execution_context"
|
||||
DIFY_PLUGIN_TOOLS_LAYER_ID = "tools"
|
||||
|
||||
@ -181,9 +182,138 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
|
||||
return value
|
||||
|
||||
|
||||
class AgentBackendAgentAppRunInput(BaseModel):
|
||||
"""Inputs to build one Agent App conversation-turn run request.
|
||||
|
||||
Unlike the workflow-node input there is no workflow-node-job prompt and no
|
||||
previous-node context: the user prompt is the chat message, and multi-turn
|
||||
continuity comes from ``session_snapshot`` + the history layer keyed by the
|
||||
conversation.
|
||||
"""
|
||||
|
||||
model: AgentBackendModelConfig
|
||||
execution_context: DifyExecutionContextLayerConfig
|
||||
user_prompt: str
|
||||
agent_soul_prompt: str | None = None
|
||||
purpose: RunPurpose = "agent_app"
|
||||
idempotency_key: str | None = None
|
||||
output: AgentBackendOutputConfig | None = None
|
||||
tools: DifyPluginToolsLayerConfig | None = None
|
||||
session_snapshot: CompositorSessionSnapshot | None = None
|
||||
include_history: bool = True
|
||||
suspend_on_exit: bool = True
|
||||
metadata: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
|
||||
@field_validator("user_prompt")
|
||||
@classmethod
|
||||
def _reject_blank_prompt(cls, value: str) -> str:
|
||||
if not value.strip():
|
||||
raise ValueError("prompt must not be blank")
|
||||
return value
|
||||
|
||||
|
||||
class AgentBackendRunRequestBuilder:
|
||||
"""Converts API product state into the public ``dify-agent`` run protocol."""
|
||||
|
||||
def build_for_agent_app(self, run_input: AgentBackendAgentAppRunInput) -> CreateRunRequest:
|
||||
"""Build an Agent App conversation-turn run request.
|
||||
|
||||
Layer graph: optional Agent Soul system prompt → user prompt →
|
||||
execution context → optional history (multi-turn) → LLM → optional
|
||||
plugin tools → optional structured output. Mirrors the workflow-node
|
||||
layer ordering minus the workflow-job / previous-node prompt.
|
||||
"""
|
||||
layers: list[RunLayerSpec] = []
|
||||
if run_input.agent_soul_prompt:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=AGENT_SOUL_PROMPT_LAYER_ID,
|
||||
type=PLAIN_PROMPT_LAYER_TYPE_ID,
|
||||
metadata={**run_input.metadata, "origin": "agent_soul"},
|
||||
config=PromptLayerConfig(prefix=run_input.agent_soul_prompt),
|
||||
)
|
||||
)
|
||||
|
||||
layers.extend(
|
||||
[
|
||||
RunLayerSpec(
|
||||
name=AGENT_APP_USER_PROMPT_LAYER_ID,
|
||||
type=PLAIN_PROMPT_LAYER_TYPE_ID,
|
||||
metadata={**run_input.metadata, "origin": "agent_app_user_prompt"},
|
||||
config=PromptLayerConfig(user=run_input.user_prompt),
|
||||
),
|
||||
RunLayerSpec(
|
||||
name=DIFY_EXECUTION_CONTEXT_LAYER_ID,
|
||||
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.execution_context,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if run_input.include_history:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_HISTORY_LAYER_ID,
|
||||
type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID,
|
||||
metadata={**run_input.metadata, "origin": "agent_session_history"},
|
||||
)
|
||||
)
|
||||
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_MODEL_LAYER_ID,
|
||||
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
|
||||
metadata=run_input.metadata,
|
||||
config=DifyPluginLLMLayerConfig(
|
||||
plugin_id=run_input.model.plugin_id,
|
||||
model_provider=run_input.model.model_provider,
|
||||
model=run_input.model.model,
|
||||
credentials=run_input.model.credentials,
|
||||
model_settings=run_input.model.model_settings or None,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.tools is not None and run_input.tools.tools:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_PLUGIN_TOOLS_LAYER_ID,
|
||||
type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
|
||||
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.tools,
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.output is not None:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_OUTPUT_LAYER_ID,
|
||||
type=DIFY_OUTPUT_LAYER_TYPE_ID,
|
||||
metadata=run_input.metadata,
|
||||
config=DifyOutputLayerConfig(
|
||||
json_schema=run_input.output.json_schema,
|
||||
description=run_input.output.description,
|
||||
strict=run_input.output.strict,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return CreateRunRequest(
|
||||
composition=RunComposition(layers=layers),
|
||||
purpose=run_input.purpose,
|
||||
idempotency_key=run_input.idempotency_key,
|
||||
metadata=run_input.metadata,
|
||||
session_snapshot=run_input.session_snapshot,
|
||||
on_exit=LayerExitSignals(
|
||||
default=ExitIntent.SUSPEND if run_input.suspend_on_exit else ExitIntent.DELETE,
|
||||
),
|
||||
)
|
||||
|
||||
def build_cleanup_request(
|
||||
self,
|
||||
*,
|
||||
|
||||
@ -4,6 +4,12 @@ CLI command modules extracted from `commands.py`.
|
||||
|
||||
from .account import create_tenant, reset_email, reset_password
|
||||
from .data_migrate import data_migrate, legacy_model_types
|
||||
from .data_migration import (
|
||||
export_migration_data,
|
||||
export_migration_data_template,
|
||||
import_migration_data,
|
||||
migration_data_wizard,
|
||||
)
|
||||
from .plugin import (
|
||||
extract_plugins,
|
||||
extract_unique_plugins,
|
||||
@ -26,7 +32,12 @@ from .retention import (
|
||||
restore_workflow_runs,
|
||||
)
|
||||
from .storage import clear_orphaned_file_records, file_usage, migrate_oss, remove_orphaned_files_on_storage
|
||||
from .system import convert_to_agent_apps, fix_app_site_missing, reset_encrypt_key_pair, upgrade_db
|
||||
from .system import (
|
||||
convert_to_agent_apps,
|
||||
fix_app_site_missing,
|
||||
reset_encrypt_key_pair,
|
||||
upgrade_db,
|
||||
)
|
||||
from .vector import (
|
||||
add_qdrant_index,
|
||||
migrate_annotation_vector_database,
|
||||
@ -48,10 +59,13 @@ __all__ = [
|
||||
"data_migrate",
|
||||
"delete_archived_workflow_runs",
|
||||
"export_app_messages",
|
||||
"export_migration_data",
|
||||
"export_migration_data_template",
|
||||
"extract_plugins",
|
||||
"extract_unique_plugins",
|
||||
"file_usage",
|
||||
"fix_app_site_missing",
|
||||
"import_migration_data",
|
||||
"install_plugins",
|
||||
"install_rag_pipeline_plugins",
|
||||
"legacy_model_types",
|
||||
@ -59,6 +73,7 @@ __all__ = [
|
||||
"migrate_data_for_plugin",
|
||||
"migrate_knowledge_vector_database",
|
||||
"migrate_oss",
|
||||
"migration_data_wizard",
|
||||
"old_metadata_migration",
|
||||
"remove_orphaned_files_on_storage",
|
||||
"reset_email",
|
||||
|
||||
754
api/commands/data_migration.py
Normal file
754
api/commands/data_migration.py
Normal file
@ -0,0 +1,754 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
import yaml
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import Tenant
|
||||
from models.model import App
|
||||
from models.tools import ApiToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.data_migration.dependency_discovery_service import DependencyDiscoveryService
|
||||
from services.data_migration.entities import (
|
||||
DependencyKind,
|
||||
ImportOptions,
|
||||
MigrationDataError,
|
||||
ReportContext,
|
||||
ResourceReportItem,
|
||||
)
|
||||
from services.data_migration.export_service import ExportConfigParser, MigrationExportService
|
||||
from services.data_migration.import_service import ImportRequest, MigrationImportService
|
||||
from services.data_migration.package_service import MigrationPackageService
|
||||
from services.data_migration.report_service import MigrationReportService
|
||||
|
||||
ID_STRATEGY_CHOICES = ["preserve-id", "generate-new-id"]
|
||||
CONFLICT_STRATEGY_CHOICES = ["fail", "skip", "update"]
|
||||
SUPPORTED_WIZARD_APP_MODES = ["workflow", "advanced-chat"]
|
||||
WizardToolMap = dict[str, dict[str, str | None]]
|
||||
WizardToolSelection = dict[str, list[str]]
|
||||
|
||||
|
||||
def _scripted_export_template() -> dict[str, Any]:
|
||||
return {
|
||||
"source_tenant": {
|
||||
"mode": "single",
|
||||
"id": "",
|
||||
"name": "admin's Workspace",
|
||||
},
|
||||
"apps": {
|
||||
"modes": ["workflow", "advanced-chat"],
|
||||
"ids": [],
|
||||
"all": True,
|
||||
},
|
||||
"include_referenced_tools": True,
|
||||
"additional_tools": {
|
||||
"api_tools": [],
|
||||
"workflow_tools": [],
|
||||
"mcp_tools": [],
|
||||
},
|
||||
"include_secrets": False,
|
||||
"import_options": {
|
||||
"create_app_api_token_on_import": False,
|
||||
"id_strategy": "preserve-id",
|
||||
"conflict_strategy": "fail",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@click.command("app-migration-template", help="Print or write a scripted export config JSON template.")
|
||||
@click.option(
|
||||
"--output",
|
||||
"output_file",
|
||||
required=False,
|
||||
type=click.Path(dir_okay=False),
|
||||
help="Path to write the export config JSON template. Prints to stdout when omitted.",
|
||||
)
|
||||
@click.option("--overwrite", is_flag=True, default=False, help="Overwrite output if it already exists.")
|
||||
def export_migration_data_template(output_file: str | None, overwrite: bool) -> None:
|
||||
template_json = json.dumps(_scripted_export_template(), indent=2, ensure_ascii=False) + "\n"
|
||||
if output_file is None:
|
||||
click.echo(template_json, nl=False)
|
||||
return
|
||||
path = Path(output_file)
|
||||
if path.exists() and not overwrite:
|
||||
raise click.ClickException(f"Output file already exists: {output_file}")
|
||||
path.write_text(template_json)
|
||||
click.echo(click.style(f"Output written to {output_file}", fg="green"))
|
||||
|
||||
|
||||
@click.command("export-app-migration", help="Export workflow migration data to a versioned JSON package.")
|
||||
@click.option(
|
||||
"--input",
|
||||
"input_file",
|
||||
required=False,
|
||||
type=click.Path(exists=True, dir_okay=False),
|
||||
help="Path to export config JSON.",
|
||||
)
|
||||
@click.option(
|
||||
"--output",
|
||||
"output_file",
|
||||
required=False,
|
||||
type=click.Path(dir_okay=False),
|
||||
help="Path to migration package JSON.",
|
||||
)
|
||||
@click.option("--overwrite", is_flag=True, default=False, help="Overwrite output if it already exists.")
|
||||
def export_migration_data(input_file: str | None, output_file: str | None, overwrite: bool) -> None:
|
||||
try:
|
||||
_require_options(("--input", input_file), ("--output", output_file))
|
||||
assert input_file is not None
|
||||
assert output_file is not None
|
||||
raw_config = _load_json_object(input_file, "Export config")
|
||||
selection = ExportConfigParser().parse(raw_config)
|
||||
result = MigrationExportService().export(selection)
|
||||
MigrationPackageService().save_package(result.package, output_file, overwrite=overwrite)
|
||||
click.echo(click.style(f"Output written to {output_file}", fg="green"))
|
||||
_render_report(result.report_items, context=_with_output_path(result.report_context, output_file))
|
||||
except MigrationDataError as exc:
|
||||
raise click.ClickException(str(exc)) from exc
|
||||
|
||||
|
||||
@click.command("import-app-migration", help="Import a versioned migration data package.")
|
||||
@click.option(
|
||||
"--input",
|
||||
"input_file",
|
||||
required=False,
|
||||
type=click.Path(exists=True, dir_okay=False),
|
||||
help="Path to migration package JSON.",
|
||||
)
|
||||
@click.option("--target-tenant", default=None, help="Target tenant/workspace name. Overrides package metadata.")
|
||||
@click.option("--operator-email", default=None, help="Operator account email in the target tenant.")
|
||||
@click.option(
|
||||
"--id-strategy",
|
||||
default=None,
|
||||
type=click.Choice(ID_STRATEGY_CHOICES),
|
||||
help="Override package ID strategy.",
|
||||
)
|
||||
@click.option(
|
||||
"--conflict-strategy",
|
||||
default=None,
|
||||
type=click.Choice(CONFLICT_STRATEGY_CHOICES),
|
||||
help="Override package conflict strategy.",
|
||||
)
|
||||
@click.option(
|
||||
"--create-app-api-token-on-import/--no-create-app-api-token-on-import",
|
||||
default=None,
|
||||
help="Override package app API token creation behavior.",
|
||||
)
|
||||
def import_migration_data(
|
||||
input_file: str | None,
|
||||
target_tenant: str | None,
|
||||
operator_email: str | None,
|
||||
id_strategy: str | None,
|
||||
conflict_strategy: str | None,
|
||||
create_app_api_token_on_import: bool | None,
|
||||
) -> None:
|
||||
try:
|
||||
_require_options(("--input", input_file))
|
||||
assert input_file is not None
|
||||
package = MigrationPackageService().load_package(input_file)
|
||||
result = MigrationImportService().import_package(
|
||||
ImportRequest(
|
||||
package=package,
|
||||
cli_target_tenant=target_tenant,
|
||||
operator_email=operator_email,
|
||||
options_override=_build_options_override(
|
||||
package.metadata.import_options,
|
||||
id_strategy=id_strategy,
|
||||
conflict_strategy=conflict_strategy,
|
||||
create_app_api_token_on_import=create_app_api_token_on_import,
|
||||
),
|
||||
)
|
||||
)
|
||||
_render_report(result.report_items, context=result.report_context)
|
||||
except MigrationDataError as exc:
|
||||
raise click.ClickException(str(exc)) from exc
|
||||
|
||||
|
||||
def parse_index_selection(raw: str, values: list[str]) -> list[str]:
|
||||
normalized = raw.strip().lower()
|
||||
if normalized == "all":
|
||||
return values
|
||||
|
||||
selected: list[str] = []
|
||||
for part in raw.split(","):
|
||||
stripped = part.strip()
|
||||
if not stripped:
|
||||
continue
|
||||
try:
|
||||
index = int(stripped)
|
||||
except ValueError as exc:
|
||||
raise click.ClickException(f"Selection must be 'all' or comma-separated numbers: {raw}") from exc
|
||||
if index < 1 or index > len(values):
|
||||
raise click.ClickException(f"Selection index out of range: {index}")
|
||||
selected.append(values[index - 1])
|
||||
return list(dict.fromkeys(selected))
|
||||
|
||||
|
||||
def _print_wizard_step(title: str) -> None:
|
||||
click.echo("")
|
||||
click.echo(f"==== {title} ====")
|
||||
|
||||
|
||||
def _print_wizard_substep(title: str) -> None:
|
||||
click.echo("")
|
||||
click.echo(f"-- {title} --")
|
||||
|
||||
|
||||
@click.command("app-migration-wizard", help="Interactively export workflow migration data.")
|
||||
def migration_data_wizard() -> None:
|
||||
try:
|
||||
tenant = _prompt_source_tenant()
|
||||
apps = _eligible_apps_for_tenant(tenant.id)
|
||||
app_ids = _prompt_app_ids(apps)
|
||||
_print_wizard_step("Referenced Tools")
|
||||
include_referenced_tools = click.confirm(
|
||||
"Automatically export tools referenced by selected apps? [y/n, default: y]",
|
||||
default=True,
|
||||
show_default=False,
|
||||
)
|
||||
auto_tools = _discover_auto_tools([app for app in apps if app.id in set(app_ids)], include_referenced_tools)
|
||||
auto_tools = _resolve_auto_tool_names(tenant.id, auto_tools)
|
||||
_print_auto_tools(auto_tools)
|
||||
additional_tools = _prompt_additional_tools(tenant.id, auto_tools)
|
||||
include_secrets, create_tokens, id_strategy, conflict_strategy = _prompt_import_options()
|
||||
_print_wizard_step("Output")
|
||||
output_file, overwrite = _prompt_output_file()
|
||||
|
||||
selection = ExportConfigParser().parse(
|
||||
{
|
||||
"source_tenant": {"mode": "single", "id": tenant.id, "name": tenant.name},
|
||||
"apps": {"ids": app_ids, "all": False},
|
||||
"include_referenced_tools": include_referenced_tools,
|
||||
"additional_tools": additional_tools,
|
||||
"include_secrets": include_secrets,
|
||||
"import_options": {
|
||||
"create_app_api_token_on_import": create_tokens,
|
||||
"id_strategy": id_strategy,
|
||||
"conflict_strategy": conflict_strategy,
|
||||
},
|
||||
}
|
||||
)
|
||||
_confirm_wizard_summary(
|
||||
tenant_name=tenant.name,
|
||||
app_names=[app.name for app in apps if app.id in set(app_ids)],
|
||||
auto_tools=auto_tools,
|
||||
additional_tools=additional_tools,
|
||||
manual_labels=_selected_tool_labels_for_tenant(tenant.id, additional_tools),
|
||||
include_referenced_tools=include_referenced_tools,
|
||||
include_secrets=include_secrets,
|
||||
create_tokens=create_tokens,
|
||||
id_strategy=id_strategy,
|
||||
conflict_strategy=conflict_strategy,
|
||||
output_file=output_file,
|
||||
)
|
||||
result = MigrationExportService().export(selection)
|
||||
MigrationPackageService().save_package(result.package, output_file, overwrite=overwrite)
|
||||
click.echo(click.style(f"Output written to {output_file}", fg="green"))
|
||||
_print_wizard_step("Report")
|
||||
_render_report(result.report_items, context=_with_output_path(result.report_context, output_file))
|
||||
except MigrationDataError as exc:
|
||||
raise click.ClickException(str(exc)) from exc
|
||||
|
||||
|
||||
def _load_json_object(path: str, label: str) -> dict[str, Any]:
|
||||
try:
|
||||
with Path(path).open(encoding="utf-8") as file:
|
||||
raw = json.load(file)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise MigrationDataError(f"{label} JSON is invalid: {exc.msg}") from exc
|
||||
if not isinstance(raw, dict):
|
||||
raise MigrationDataError(f"{label} JSON must be an object.")
|
||||
return raw
|
||||
|
||||
|
||||
def _require_options(*options: tuple[str, object | None]) -> None:
|
||||
missing_options = [name for name, value in options if value is None]
|
||||
if missing_options:
|
||||
raise click.UsageError(f"Missing option(s): {', '.join(missing_options)}.")
|
||||
|
||||
|
||||
def _build_options_override(
|
||||
package_options: ImportOptions,
|
||||
*,
|
||||
id_strategy: str | None,
|
||||
conflict_strategy: str | None,
|
||||
create_app_api_token_on_import: bool | None,
|
||||
) -> ImportOptions | None:
|
||||
if id_strategy is None and conflict_strategy is None and create_app_api_token_on_import is None:
|
||||
return None
|
||||
return ImportOptions.from_mapping(
|
||||
{
|
||||
"id_strategy": id_strategy or package_options.id_strategy,
|
||||
"conflict_strategy": conflict_strategy or package_options.conflict_strategy,
|
||||
"create_app_api_token_on_import": (
|
||||
create_app_api_token_on_import
|
||||
if create_app_api_token_on_import is not None
|
||||
else package_options.create_app_api_token_on_import
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _prompt_source_tenant() -> Tenant:
|
||||
tenants = list(db.session.scalars(sa.select(Tenant).order_by(Tenant.name.asc())).all())
|
||||
if not tenants:
|
||||
raise MigrationDataError("No tenants found.")
|
||||
|
||||
_print_wizard_step("Source Tenant")
|
||||
click.echo("Source tenants:")
|
||||
for index, tenant in enumerate(tenants, 1):
|
||||
click.echo(f"{index}. {tenant.name} ({tenant.id})")
|
||||
|
||||
tenant_index = click.prompt("Select one source tenant by number", type=int, default=1, show_default=True)
|
||||
if tenant_index < 1 or tenant_index > len(tenants):
|
||||
raise click.ClickException(f"Selection index out of range: {tenant_index}")
|
||||
return tenants[tenant_index - 1]
|
||||
|
||||
|
||||
def _eligible_apps_for_tenant(tenant_id: str) -> list[App]:
|
||||
return list(
|
||||
db.session.scalars(
|
||||
sa.select(App)
|
||||
.where(App.tenant_id == tenant_id, App.mode.in_(SUPPORTED_WIZARD_APP_MODES))
|
||||
.order_by(App.name.asc())
|
||||
).all()
|
||||
)
|
||||
|
||||
|
||||
def _prompt_app_ids(apps: list[App]) -> list[str]:
|
||||
if not apps:
|
||||
raise MigrationDataError("No workflow or advanced-chat apps found for the selected tenant.")
|
||||
|
||||
_print_wizard_step("App Selection")
|
||||
click.echo("Currently supported app types: workflow and chatflow.")
|
||||
click.echo("Workflow/chatflow apps:")
|
||||
for index, app in enumerate(apps, 1):
|
||||
mode = app.mode.value if hasattr(app.mode, "value") else app.mode
|
||||
click.echo(f"{index}. {app.name} [{mode}] ({app.id})")
|
||||
app_ids = parse_index_selection(
|
||||
click.prompt("Select apps by number, comma-separated numbers, or all", default="all"),
|
||||
[app.id for app in apps],
|
||||
)
|
||||
selected_apps = [app for app in apps if app.id in set(app_ids)]
|
||||
click.echo("Selected apps:")
|
||||
for app in selected_apps:
|
||||
click.echo(f"- {app.name} ({app.id})")
|
||||
return app_ids
|
||||
|
||||
|
||||
def _prompt_import_options() -> tuple[bool, bool, str, str]:
|
||||
_print_wizard_step("Import Options")
|
||||
_print_wizard_substep("Secrets")
|
||||
click.echo("Secrets include workflow/app DSL secret values, custom API tool credentials,")
|
||||
click.echo("and full MCP provider connection data such as server URL, headers, authentication, and tool list.")
|
||||
click.echo("If you choose no, credentials are omitted or masked,")
|
||||
click.echo("and MCP providers are exported as dependency metadata only.")
|
||||
click.echo("Treat the output JSON as sensitive if you choose yes.")
|
||||
include_secrets = click.confirm(
|
||||
"Include secrets in output JSON? [y/n, default: n]",
|
||||
default=False,
|
||||
show_default=False,
|
||||
)
|
||||
_print_wizard_substep("App API Tokens")
|
||||
click.echo("When enabled, import will create an app API token if the imported app has none,")
|
||||
click.echo("or reuse an existing app API token if one already exists.")
|
||||
create_tokens = click.confirm(
|
||||
"Create or reuse app API tokens during import? [y/n, default: n]",
|
||||
default=False,
|
||||
show_default=False,
|
||||
)
|
||||
_print_wizard_substep("ID Strategy")
|
||||
click.echo("ID strategy controls whether imported app and tool IDs preserve source IDs")
|
||||
click.echo("or use target-generated IDs.")
|
||||
click.echo("preserve-id: keep source IDs where the target service supports it.")
|
||||
click.echo("generate-new-id: let the target environment generate new IDs and rewrite references via mapping.")
|
||||
id_strategy = click.prompt(
|
||||
"Import ID strategy. Enter one of: preserve-id, generate-new-id",
|
||||
type=click.Choice(ID_STRATEGY_CHOICES),
|
||||
default="preserve-id",
|
||||
show_default=True,
|
||||
)
|
||||
_print_wizard_substep("Conflict Strategy")
|
||||
click.echo("Conflict strategy controls what import does when a target resource already exists.")
|
||||
click.echo("fail: stop at the first conflict; previously committed resources are not rolled back.")
|
||||
click.echo("skip: keep the existing target resource and skip importing that resource.")
|
||||
click.echo("update: update the existing target resource in place.")
|
||||
conflict_strategy = click.prompt(
|
||||
"Import conflict strategy. Enter one of: fail, skip, update",
|
||||
type=click.Choice(CONFLICT_STRATEGY_CHOICES),
|
||||
default="update",
|
||||
show_default=True,
|
||||
)
|
||||
return include_secrets, create_tokens, id_strategy, conflict_strategy
|
||||
|
||||
|
||||
def _discover_auto_tools(apps: list[App], include_referenced_tools: bool) -> WizardToolMap:
|
||||
auto_tools: WizardToolMap = {"api_tools": {}, "workflow_tools": {}, "mcp_tools": {}}
|
||||
if not include_referenced_tools:
|
||||
return auto_tools
|
||||
discovery_service = DependencyDiscoveryService()
|
||||
for app in apps:
|
||||
dsl_content = AppDslService.export_dsl(app_model=app, include_secret=False)
|
||||
raw_dsl = yaml.safe_load(dsl_content) if dsl_content else {}
|
||||
dsl = raw_dsl if isinstance(raw_dsl, dict) else {}
|
||||
for dependency in discovery_service.discover_from_dsl(dsl):
|
||||
if dependency.kind == DependencyKind.API_TOOL:
|
||||
auto_tools["api_tools"][dependency.provider_name or dependency.provider_id] = dependency.provider_id
|
||||
elif dependency.kind == DependencyKind.WORKFLOW_TOOL:
|
||||
auto_tools["workflow_tools"][dependency.provider_name or dependency.provider_id] = (
|
||||
dependency.provider_id
|
||||
)
|
||||
elif dependency.kind == DependencyKind.MCP_TOOL:
|
||||
auto_tools["mcp_tools"][dependency.provider_name or dependency.provider_id] = dependency.provider_id
|
||||
return auto_tools
|
||||
|
||||
|
||||
def _resolve_auto_tool_names(tenant_id: str, auto_tools: WizardToolMap) -> WizardToolMap:
|
||||
return {
|
||||
"api_tools": _resolve_api_tool_names(tenant_id, auto_tools["api_tools"]),
|
||||
"workflow_tools": _resolve_workflow_tool_names(tenant_id, auto_tools["workflow_tools"]),
|
||||
"mcp_tools": _resolve_mcp_tool_names(tenant_id, auto_tools["mcp_tools"]),
|
||||
}
|
||||
|
||||
|
||||
def _resolve_api_tool_names(tenant_id: str, tools: dict[str, str | None]) -> dict[str, str | None]:
|
||||
resolved: dict[str, str | None] = {}
|
||||
for name, identifier in tools.items():
|
||||
predicates = [ApiToolProvider.name == name]
|
||||
if _is_uuid_string(identifier):
|
||||
predicates.append(ApiToolProvider.id == identifier)
|
||||
provider = db.session.scalar(
|
||||
sa.select(ApiToolProvider).where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
sa.or_(*predicates),
|
||||
)
|
||||
)
|
||||
resolved[provider.name if provider else name] = provider.id if provider else identifier
|
||||
return resolved
|
||||
|
||||
|
||||
def _resolve_workflow_tool_names(tenant_id: str, tools: dict[str, str | None]) -> dict[str, str | None]:
|
||||
resolved: dict[str, str | None] = {}
|
||||
for name, identifier in tools.items():
|
||||
predicates = [WorkflowToolProvider.name == name]
|
||||
if _is_uuid_string(identifier):
|
||||
predicates.append(WorkflowToolProvider.id == identifier)
|
||||
provider = db.session.scalar(
|
||||
sa.select(WorkflowToolProvider).where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
sa.or_(*predicates),
|
||||
)
|
||||
)
|
||||
resolved[provider.name if provider else name] = provider.id if provider else identifier
|
||||
return resolved
|
||||
|
||||
|
||||
def _resolve_mcp_tool_names(tenant_id: str, tools: dict[str, str | None]) -> dict[str, str | None]:
|
||||
resolved: dict[str, str | None] = {}
|
||||
for name, identifier in tools.items():
|
||||
predicates = [MCPToolProvider.name == name]
|
||||
if identifier:
|
||||
predicates.append(MCPToolProvider.server_identifier == identifier)
|
||||
if _is_uuid_string(identifier):
|
||||
predicates.append(MCPToolProvider.id == identifier)
|
||||
provider = db.session.scalar(
|
||||
sa.select(MCPToolProvider).where(
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
sa.or_(*predicates),
|
||||
)
|
||||
)
|
||||
resolved[provider.name if provider else name] = provider.id if provider else identifier
|
||||
return resolved
|
||||
|
||||
|
||||
def _is_uuid_string(value: str | None) -> bool:
|
||||
if not value:
|
||||
return False
|
||||
try:
|
||||
UUID(value)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _print_auto_tools(auto_tools: WizardToolMap) -> None:
|
||||
_print_wizard_step("Automatically Discovered Tools")
|
||||
click.echo("Automatically discovered tools:")
|
||||
_print_auto_tool_category("Custom API tools", auto_tools["api_tools"])
|
||||
_print_auto_tool_category("Workflow tools", auto_tools["workflow_tools"])
|
||||
_print_auto_tool_category("MCP tools", auto_tools["mcp_tools"])
|
||||
|
||||
|
||||
def _print_auto_tool_category(label: str, values: dict[str, str | None]) -> None:
|
||||
click.echo(label)
|
||||
if not values:
|
||||
click.echo("- none")
|
||||
return
|
||||
for name, identifier in sorted(values.items()):
|
||||
click.echo(f"- {_format_tool_name_id(name, identifier)}")
|
||||
|
||||
|
||||
def _prompt_additional_tools(tenant_id: str, auto_tools: WizardToolMap) -> WizardToolSelection:
|
||||
selections: WizardToolSelection = {"api_tools": [], "workflow_tools": [], "mcp_tools": []}
|
||||
_print_wizard_step("Additional Tools")
|
||||
if not click.confirm(
|
||||
"Export additional tools manually? [y/n, default: n]",
|
||||
default=False,
|
||||
show_default=False,
|
||||
):
|
||||
_print_final_tool_selection(auto_tools, selections, {})
|
||||
return selections
|
||||
manual_labels: dict[str, str] = {}
|
||||
api_tool_options = [
|
||||
(tool.name, tool.name, tool.id)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).order_by(ApiToolProvider.name)
|
||||
).all()
|
||||
]
|
||||
selections["api_tools"] = _prompt_tool_category(
|
||||
"Custom API tools",
|
||||
api_tool_options,
|
||||
auto_tools=auto_tools["api_tools"],
|
||||
)
|
||||
manual_labels.update(_selected_tool_labels(api_tool_options, selections["api_tools"]))
|
||||
workflow_tool_options = [
|
||||
(tool.id, tool.name, tool.id)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id)
|
||||
.order_by(WorkflowToolProvider.name)
|
||||
).all()
|
||||
]
|
||||
selections["workflow_tools"] = _prompt_tool_category(
|
||||
"Workflow tools",
|
||||
workflow_tool_options,
|
||||
auto_tools=auto_tools["workflow_tools"],
|
||||
)
|
||||
manual_labels.update(_selected_tool_labels(workflow_tool_options, selections["workflow_tools"]))
|
||||
mcp_tool_options = [
|
||||
(tool.id, tool.name, tool.server_identifier)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name)
|
||||
).all()
|
||||
]
|
||||
selections["mcp_tools"] = _prompt_tool_category(
|
||||
"MCP tools",
|
||||
mcp_tool_options,
|
||||
auto_tools=auto_tools["mcp_tools"],
|
||||
)
|
||||
manual_labels.update(_selected_tool_labels(mcp_tool_options, selections["mcp_tools"]))
|
||||
_print_final_tool_selection(auto_tools, selections, manual_labels)
|
||||
return selections
|
||||
|
||||
|
||||
def _selected_tool_labels_for_tenant(tenant_id: str, selected_tools: WizardToolSelection) -> dict[str, str]:
|
||||
labels: dict[str, str] = {}
|
||||
if selected_tools["api_tools"]:
|
||||
labels.update(
|
||||
_selected_tool_labels(
|
||||
[
|
||||
(tool.name, tool.name, tool.id)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(ApiToolProvider)
|
||||
.where(ApiToolProvider.tenant_id == tenant_id)
|
||||
.order_by(ApiToolProvider.name)
|
||||
).all()
|
||||
],
|
||||
selected_tools["api_tools"],
|
||||
)
|
||||
)
|
||||
if selected_tools["workflow_tools"]:
|
||||
labels.update(
|
||||
_selected_tool_labels(
|
||||
[
|
||||
(tool.id, tool.name, tool.id)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id)
|
||||
.order_by(WorkflowToolProvider.name)
|
||||
).all()
|
||||
],
|
||||
selected_tools["workflow_tools"],
|
||||
)
|
||||
)
|
||||
if selected_tools["mcp_tools"]:
|
||||
labels.update(
|
||||
_selected_tool_labels(
|
||||
[
|
||||
(tool.id, tool.name, tool.server_identifier)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(MCPToolProvider)
|
||||
.where(MCPToolProvider.tenant_id == tenant_id)
|
||||
.order_by(MCPToolProvider.name)
|
||||
).all()
|
||||
],
|
||||
selected_tools["mcp_tools"],
|
||||
)
|
||||
)
|
||||
return labels
|
||||
|
||||
|
||||
def _selected_tool_labels(options: list[tuple[str, str, str]], selected_values: list[str]) -> dict[str, str]:
|
||||
selected = set(selected_values)
|
||||
return {value: _format_tool_name_id(name, detail) for value, name, detail in options if value in selected}
|
||||
|
||||
|
||||
def _prompt_tool_category(
|
||||
label: str,
|
||||
options: list[tuple[str, str, str]],
|
||||
*,
|
||||
auto_tools: dict[str, str | None],
|
||||
) -> list[str]:
|
||||
if not options:
|
||||
click.echo(f"{label}: none")
|
||||
return []
|
||||
_print_wizard_step(label)
|
||||
for index, (value, name, detail) in enumerate(options, 1):
|
||||
marker = "[auto]" if _is_auto_tool(value, name, detail, auto_tools) else "[ ]"
|
||||
click.echo(f"{index}. {marker} {name} ({detail})")
|
||||
raw = click.prompt(
|
||||
f"Select {label.lower()} by number, comma-separated numbers, all, or empty",
|
||||
default="",
|
||||
show_default=cast(Any, "empty"),
|
||||
)
|
||||
if not raw.strip():
|
||||
return []
|
||||
return parse_index_selection(raw, [value for value, _, _ in options])
|
||||
|
||||
|
||||
def _is_auto_tool(value: str, name: str, detail: str, auto_tools: dict[str, str | None]) -> bool:
|
||||
return name in auto_tools or value in auto_tools or value in auto_tools.values() or detail in auto_tools.values()
|
||||
|
||||
|
||||
def _print_final_tool_selection(
|
||||
auto_tools: WizardToolMap,
|
||||
additional_tools: WizardToolSelection,
|
||||
manual_labels: dict[str, str],
|
||||
) -> None:
|
||||
_print_wizard_step("Final Tool Selection")
|
||||
_print_tool_selection_body(auto_tools, additional_tools, manual_labels)
|
||||
|
||||
|
||||
def _print_tool_selection_body(
|
||||
auto_tools: WizardToolMap,
|
||||
additional_tools: WizardToolSelection,
|
||||
manual_labels: dict[str, str],
|
||||
) -> None:
|
||||
click.echo("Final tools to export:")
|
||||
_print_final_tool_category(
|
||||
"Custom API tools",
|
||||
auto_tools["api_tools"],
|
||||
additional_tools["api_tools"],
|
||||
manual_labels,
|
||||
)
|
||||
_print_final_tool_category(
|
||||
"Workflow tools",
|
||||
auto_tools["workflow_tools"],
|
||||
additional_tools["workflow_tools"],
|
||||
manual_labels,
|
||||
)
|
||||
_print_final_tool_category("MCP tools", auto_tools["mcp_tools"], additional_tools["mcp_tools"], manual_labels)
|
||||
|
||||
|
||||
def _print_final_tool_category(
|
||||
label: str,
|
||||
auto_tools: dict[str, str | None],
|
||||
manual_values: list[str],
|
||||
manual_labels: dict[str, str],
|
||||
) -> None:
|
||||
click.echo(label)
|
||||
lines = [f"- [auto] {_format_tool_name_id(name, identifier)}" for name, identifier in sorted(auto_tools.items())]
|
||||
auto_identifiers = {identifier for identifier in auto_tools.values() if identifier}
|
||||
lines.extend(
|
||||
f"- [manual] {manual_labels.get(value, value)}"
|
||||
for value in manual_values
|
||||
if value not in auto_tools and value not in auto_identifiers
|
||||
)
|
||||
if not lines:
|
||||
click.echo("- none")
|
||||
return
|
||||
for line in lines:
|
||||
click.echo(line)
|
||||
|
||||
|
||||
def _format_tool_name_id(name: str, identifier: str | None) -> str:
|
||||
if identifier and identifier != name:
|
||||
return f"{name}: {identifier}"
|
||||
return name
|
||||
|
||||
|
||||
def _confirm_wizard_summary(
|
||||
*,
|
||||
tenant_name: str,
|
||||
app_names: list[str],
|
||||
auto_tools: WizardToolMap,
|
||||
additional_tools: WizardToolSelection,
|
||||
manual_labels: dict[str, str],
|
||||
include_referenced_tools: bool,
|
||||
include_secrets: bool,
|
||||
create_tokens: bool,
|
||||
id_strategy: str,
|
||||
conflict_strategy: str,
|
||||
output_file: str,
|
||||
) -> None:
|
||||
_print_wizard_step("Summary")
|
||||
click.echo("Migration export summary:")
|
||||
click.echo(f"source tenant: {tenant_name}")
|
||||
click.echo(f"selected apps: {len(app_names)}")
|
||||
for app_name in app_names:
|
||||
click.echo(f"- {app_name}")
|
||||
click.echo(f"auto referenced tools: {str(include_referenced_tools).lower()}")
|
||||
_print_tool_selection_body(auto_tools, additional_tools, manual_labels)
|
||||
click.echo(f"include secrets: {str(include_secrets).lower()}")
|
||||
click.echo(f"create app api token on import: {str(create_tokens).lower()}")
|
||||
click.echo(f"id strategy: {id_strategy}")
|
||||
click.echo(f"conflict strategy: {conflict_strategy}")
|
||||
click.echo(f"output path: {output_file}")
|
||||
if not click.confirm("Write migration package? [y/n, default: y]", default=True, show_default=False):
|
||||
raise click.Abort()
|
||||
|
||||
|
||||
def _prompt_output_file() -> tuple[str, bool]:
|
||||
default_output = f"migration-data-{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
||||
output_file = click.prompt("Output path", default=default_output, show_default=True)
|
||||
if output_file.lower() in {"y", "yes", "n", "no"}:
|
||||
raise click.ClickException("Output path must be a file path. Press Enter to use the default path.")
|
||||
overwrite = False
|
||||
if Path(output_file).exists():
|
||||
overwrite = click.confirm(
|
||||
"Output file exists. Overwrite? [y/n, default: n]",
|
||||
default=False,
|
||||
show_default=False,
|
||||
)
|
||||
if not overwrite:
|
||||
raise click.ClickException(f"Output file already exists: {output_file}")
|
||||
return output_file, overwrite
|
||||
|
||||
|
||||
def _with_output_path(context: ReportContext | None, output_path: str) -> ReportContext:
|
||||
if context is None:
|
||||
return ReportContext(output_path=output_path)
|
||||
return ReportContext(
|
||||
output_path=output_path,
|
||||
source_scope=context.source_scope,
|
||||
selected_app_count=context.selected_app_count,
|
||||
include_secrets=context.include_secrets,
|
||||
target_tenant=context.target_tenant,
|
||||
operator_email=context.operator_email,
|
||||
app_api_tokens_created=context.app_api_tokens_created,
|
||||
app_api_tokens_reused=context.app_api_tokens_reused,
|
||||
id_mapping_count=context.id_mapping_count,
|
||||
id_mappings=context.id_mappings,
|
||||
)
|
||||
|
||||
|
||||
def _render_report(report_items: list[ResourceReportItem], *, context: ReportContext | None = None) -> None:
|
||||
for line in MigrationReportService().render(report_items, context=context):
|
||||
click.echo(line)
|
||||
@ -30,7 +30,7 @@ def vdb_migrate(scope: str):
|
||||
|
||||
def migrate_annotation_vector_database():
|
||||
"""
|
||||
Migrate annotation datas to target vector database .
|
||||
Migrate annotation data to target vector database.
|
||||
"""
|
||||
click.echo(click.style("Starting annotation data migration.", fg="green"))
|
||||
create_count = 0
|
||||
@ -140,7 +140,7 @@ def migrate_annotation_vector_database():
|
||||
|
||||
def migrate_knowledge_vector_database():
|
||||
"""
|
||||
Migrate vector database datas to target vector database .
|
||||
Migrate vector database data to target vector database.
|
||||
"""
|
||||
click.echo(click.style("Starting vector database migration.", fg="green"))
|
||||
create_count = 0
|
||||
|
||||
@ -29,6 +29,7 @@ class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def __call__(self) -> dict[str, Any]:
|
||||
current_state = self.current_state
|
||||
remote_source_name = current_state.get("REMOTE_SETTINGS_SOURCE_NAME")
|
||||
|
||||
@ -81,4 +81,15 @@ default_app_templates: Mapping[AppMode, Mapping] = {
|
||||
},
|
||||
},
|
||||
},
|
||||
# agent default mode (new Agent App type). The runtime model / prompt / tools
|
||||
# come from the bound Agent Soul snapshot, so no model_config is seeded in the
|
||||
# template; create_app still creates a model-less app_model_config row to hold
|
||||
# app-level presentation features (opener, follow-up, citations, ...).
|
||||
AppMode.AGENT: {
|
||||
"app": {
|
||||
"mode": AppMode.AGENT,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -6,10 +6,11 @@ These helpers keep that translation centralized so models registered through
|
||||
`register_schema_models` emit resolvable Swagger 2.0 references.
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Iterable, Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal, NotRequired, TypedDict
|
||||
from typing import Any, Literal, NotRequired, Protocol, TypedDict
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Namespace
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
@ -36,6 +37,12 @@ QueryParamDoc = TypedDict(
|
||||
)
|
||||
|
||||
|
||||
class QueryArgs(Protocol):
|
||||
def to_dict(self, flat: bool = True) -> dict[str, str]: ...
|
||||
|
||||
def getlist(self, key: str) -> list[str]: ...
|
||||
|
||||
|
||||
def _register_json_schema(namespace: Namespace, name: str, schema: dict) -> None:
|
||||
"""Register a JSON schema and promote any nested Pydantic `$defs`."""
|
||||
|
||||
@ -167,6 +174,58 @@ def query_params_from_model(model: type[BaseModel]) -> dict[str, QueryParamDoc]:
|
||||
return params
|
||||
|
||||
|
||||
def query_params_from_request[ModelT: BaseModel](
|
||||
model: type[ModelT],
|
||||
*,
|
||||
list_fields: Iterable[str] = (),
|
||||
args: QueryArgs | None = None,
|
||||
use_defaults_for_malformed_ints: bool = False,
|
||||
) -> ModelT:
|
||||
"""Validate query args with Pydantic while preserving Flask query parsing behavior.
|
||||
|
||||
Repeated params need explicit ``getlist()`` handling because Werkzeug's
|
||||
``to_dict()`` keeps only one value. For malformed scalar integers, Flask's
|
||||
For endpoints migrated from ``request.args.get(..., type=int, default=...)``,
|
||||
set ``use_defaults_for_malformed_ints`` to preserve Flask's fallback to
|
||||
defaults for malformed optional integer params.
|
||||
"""
|
||||
|
||||
query_args = args or request.args
|
||||
params: dict[str, Any] = query_args.to_dict()
|
||||
for field_name in list_fields:
|
||||
params[field_name] = query_args.getlist(field_name)
|
||||
|
||||
if use_defaults_for_malformed_ints:
|
||||
_drop_malformed_defaulted_integer_params(model, params)
|
||||
return model.model_validate(params)
|
||||
|
||||
|
||||
def _drop_malformed_defaulted_integer_params(model: type[BaseModel], params: dict[str, Any]) -> None:
|
||||
properties = model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0).get("properties", {})
|
||||
if not isinstance(properties, Mapping):
|
||||
return
|
||||
|
||||
for name, value in list(params.items()):
|
||||
if not isinstance(value, str):
|
||||
continue
|
||||
|
||||
field = model.model_fields.get(name)
|
||||
if field is None or field.is_required():
|
||||
continue
|
||||
|
||||
property_schema = properties.get(name)
|
||||
if not isinstance(property_schema, Mapping):
|
||||
continue
|
||||
|
||||
if _nullable_property_schema(property_schema).get("type") != "integer":
|
||||
continue
|
||||
|
||||
try:
|
||||
int(value)
|
||||
except ValueError:
|
||||
params.pop(name)
|
||||
|
||||
|
||||
def _query_param_from_property(property_schema: Mapping[str, Any], *, required: bool) -> QueryParamDoc:
|
||||
param_schema = _nullable_property_schema(property_schema)
|
||||
param_doc: QueryParamDoc = {"in": "query", "required": required}
|
||||
@ -239,6 +298,7 @@ __all__ = [
|
||||
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
|
||||
"get_or_create_model",
|
||||
"query_params_from_model",
|
||||
"query_params_from_request",
|
||||
"register_enum_models",
|
||||
"register_response_schema_model",
|
||||
"register_response_schema_models",
|
||||
|
||||
@ -51,6 +51,8 @@ from .agent import roster as agent_roster
|
||||
from .app import (
|
||||
advanced_prompt_template,
|
||||
agent,
|
||||
agent_app_access,
|
||||
agent_app_feature,
|
||||
annotation,
|
||||
app,
|
||||
audio,
|
||||
@ -146,6 +148,8 @@ __all__ = [
|
||||
"activate",
|
||||
"advanced_prompt_template",
|
||||
"agent",
|
||||
"agent_app_access",
|
||||
"agent_app_feature",
|
||||
"agent_composer",
|
||||
"agent_providers",
|
||||
"agent_roster",
|
||||
|
||||
@ -1,53 +1,91 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user_id,
|
||||
)
|
||||
from fields.agent_fields import (
|
||||
AgentAppComposerResponse,
|
||||
AgentComposerCandidatesResponse,
|
||||
AgentComposerImpactResponse,
|
||||
AgentComposerValidateResponse,
|
||||
WorkflowAgentComposerResponse,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppMode
|
||||
from services.agent.composer_service import AgentComposerService
|
||||
from services.agent.composer_validator import ComposerConfigValidator
|
||||
from services.entities.agent_entities import ComposerSavePayload
|
||||
|
||||
register_schema_models(console_ns, ComposerSavePayload)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
AgentAppComposerResponse,
|
||||
AgentComposerCandidatesResponse,
|
||||
AgentComposerImpactResponse,
|
||||
AgentComposerValidateResponse,
|
||||
WorkflowAgentComposerResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer")
|
||||
class WorkflowAgentComposerApi(Resource):
|
||||
@console_ns.response(
|
||||
200, "Workflow agent composer state", console_ns.models[WorkflowAgentComposerResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model: App, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return AgentComposerService.load_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App, node_id: str):
|
||||
return dump_response(
|
||||
WorkflowAgentComposerResponse,
|
||||
AgentComposerService.load_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
),
|
||||
)
|
||||
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@console_ns.response(
|
||||
200, "Workflow agent composer saved", console_ns.models[WorkflowAgentComposerResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def put(self, app_model: App, node_id: str):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def put(self, tenant_id: str, account_id: str, app_model: App, node_id: str):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return AgentComposerService.save_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
account_id=account.id,
|
||||
payload=payload,
|
||||
return dump_response(
|
||||
WorkflowAgentComposerResponse,
|
||||
AgentComposerService.save_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
account_id=account_id,
|
||||
payload=payload,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/validate")
|
||||
class WorkflowAgentComposerValidateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@console_ns.response(
|
||||
200, "Workflow agent composer validation result", console_ns.models[AgentComposerValidateResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -55,84 +93,116 @@ class WorkflowAgentComposerValidateApi(Resource):
|
||||
def post(self, app_model: App, node_id: str):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
ComposerConfigValidator.validate_save_payload(payload)
|
||||
return {"result": "success", "errors": []}
|
||||
return dump_response(AgentComposerValidateResponse, {"result": "success", "errors": []})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/candidates")
|
||||
class WorkflowAgentComposerCandidatesApi(Resource):
|
||||
@console_ns.response(
|
||||
200, "Workflow agent composer candidates", console_ns.models[AgentComposerCandidatesResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model: App, node_id: str):
|
||||
return AgentComposerService.get_workflow_candidates(app_id=app_model.id)
|
||||
return dump_response(
|
||||
AgentComposerCandidatesResponse,
|
||||
AgentComposerService.get_workflow_candidates(app_id=app_model.id),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/impact")
|
||||
class WorkflowAgentComposerImpactApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@console_ns.response(200, "Workflow agent composer impact", console_ns.models[AgentComposerImpactResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, app_model: App, node_id: str):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
current_snapshot_id = payload.binding.current_snapshot_id if payload.binding else None
|
||||
if not current_snapshot_id:
|
||||
return {"current_snapshot_id": None, "workflow_node_count": 0, "bindings": []}
|
||||
return AgentComposerService.calculate_impact(tenant_id=tenant_id, current_snapshot_id=current_snapshot_id)
|
||||
return dump_response(
|
||||
AgentComposerImpactResponse, {"current_snapshot_id": None, "workflow_node_count": 0, "bindings": []}
|
||||
)
|
||||
return dump_response(
|
||||
AgentComposerImpactResponse,
|
||||
AgentComposerService.calculate_impact(tenant_id=tenant_id, current_snapshot_id=current_snapshot_id),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/save-to-roster")
|
||||
class WorkflowAgentComposerSaveToRosterApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@console_ns.response(
|
||||
200, "Workflow agent composer saved to roster", console_ns.models[WorkflowAgentComposerResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, account_id: str, app_model: App, node_id: str):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return AgentComposerService.save_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
account_id=account.id,
|
||||
payload=payload,
|
||||
return dump_response(
|
||||
WorkflowAgentComposerResponse,
|
||||
AgentComposerService.save_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
account_id=account_id,
|
||||
payload=payload,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-composer")
|
||||
class AgentAppComposerApi(Resource):
|
||||
@console_ns.response(200, "Agent app composer state", console_ns.models[AgentAppComposerResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return AgentComposerService.load_agent_app_composer(tenant_id=tenant_id, app_id=app_model.id)
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App):
|
||||
return dump_response(
|
||||
AgentAppComposerResponse,
|
||||
AgentComposerService.load_agent_app_composer(tenant_id=tenant_id, app_id=app_model.id),
|
||||
)
|
||||
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@console_ns.response(200, "Agent app composer saved", console_ns.models[AgentAppComposerResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model()
|
||||
def put(self, app_model: App):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def put(self, tenant_id: str, account_id: str, app_model: App):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return AgentComposerService.save_agent_app_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
account_id=account.id,
|
||||
payload=payload,
|
||||
return dump_response(
|
||||
AgentAppComposerResponse,
|
||||
AgentComposerService.save_agent_app_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
account_id=account_id,
|
||||
payload=payload,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-composer/validate")
|
||||
class AgentAppComposerValidateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@console_ns.response(
|
||||
200, "Agent app composer validation result", console_ns.models[AgentComposerValidateResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -140,14 +210,20 @@ class AgentAppComposerValidateApi(Resource):
|
||||
def post(self, app_model: App):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
ComposerConfigValidator.validate_save_payload(payload)
|
||||
return {"result": "success", "errors": []}
|
||||
return dump_response(AgentComposerValidateResponse, {"result": "success", "errors": []})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-composer/candidates")
|
||||
class AgentAppComposerCandidatesApi(Resource):
|
||||
@console_ns.response(
|
||||
200, "Agent app composer candidates", console_ns.models[AgentComposerCandidatesResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App):
|
||||
return AgentComposerService.get_agent_app_candidates(app_id=app_model.id)
|
||||
return dump_response(
|
||||
AgentComposerCandidatesResponse,
|
||||
AgentComposerService.get_agent_app_candidates(app_id=app_model.id),
|
||||
)
|
||||
|
||||
@ -4,11 +4,25 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user_id,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from fields.agent_fields import (
|
||||
AgentConfigSnapshotDetailResponse,
|
||||
AgentConfigSnapshotListResponse,
|
||||
AgentInviteOptionsResponse,
|
||||
AgentRosterListResponse,
|
||||
AgentRosterResponse,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from services.agent.roster_service import AgentRosterService
|
||||
from services.entities.agent_entities import RosterAgentCreatePayload, RosterAgentUpdatePayload, RosterListQuery
|
||||
|
||||
@ -29,6 +43,14 @@ register_schema_models(
|
||||
RosterAgentUpdatePayload,
|
||||
RosterListQuery,
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
AgentConfigSnapshotDetailResponse,
|
||||
AgentConfigSnapshotListResponse,
|
||||
AgentInviteOptionsResponse,
|
||||
AgentRosterListResponse,
|
||||
AgentRosterResponse,
|
||||
)
|
||||
|
||||
|
||||
def _agent_roster_service() -> AgentRosterService:
|
||||
@ -37,96 +59,130 @@ def _agent_roster_service() -> AgentRosterService:
|
||||
|
||||
@console_ns.route("/agents")
|
||||
class AgentRosterListApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(RosterListQuery))
|
||||
@console_ns.response(200, "Agent roster list", console_ns.models[AgentRosterListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
query = RosterListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
return _agent_roster_service().list_roster_agents(
|
||||
tenant_id=tenant_id, page=query.page, limit=query.limit, keyword=query.keyword
|
||||
return dump_response(
|
||||
AgentRosterListResponse,
|
||||
_agent_roster_service().list_roster_agents(
|
||||
tenant_id=tenant_id, page=query.page, limit=query.limit, keyword=query.keyword
|
||||
),
|
||||
)
|
||||
|
||||
@console_ns.expect(console_ns.models[RosterAgentCreatePayload.__name__])
|
||||
@console_ns.response(201, "Agent created", console_ns.models[AgentRosterResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, account_id: str):
|
||||
payload = RosterAgentCreatePayload.model_validate(console_ns.payload or {})
|
||||
service = _agent_roster_service()
|
||||
agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account.id, payload=payload)
|
||||
return service.get_roster_agent_detail(tenant_id=tenant_id, agent_id=agent.id), 201
|
||||
agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account_id, payload=payload)
|
||||
return dump_response(
|
||||
AgentRosterResponse,
|
||||
service.get_roster_agent_detail(tenant_id=tenant_id, agent_id=agent.id),
|
||||
), 201
|
||||
|
||||
|
||||
@console_ns.route("/agents/invite-options")
|
||||
class AgentInviteOptionsApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(AgentInviteOptionsQuery))
|
||||
@console_ns.response(200, "Agent invite options", console_ns.models[AgentInviteOptionsResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
query = AgentInviteOptionsQuery.model_validate(request.args.to_dict(flat=True))
|
||||
return _agent_roster_service().list_invite_options(
|
||||
tenant_id=tenant_id,
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
keyword=query.keyword,
|
||||
app_id=query.app_id,
|
||||
return dump_response(
|
||||
AgentInviteOptionsResponse,
|
||||
_agent_roster_service().list_invite_options(
|
||||
tenant_id=tenant_id,
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
keyword=query.keyword,
|
||||
app_id=query.app_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/agents/<uuid:agent_id>")
|
||||
class AgentRosterDetailApi(Resource):
|
||||
@console_ns.response(200, "Agent detail", console_ns.models[AgentRosterResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id: UUID):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return _agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id))
|
||||
|
||||
@console_ns.expect(console_ns.models[RosterAgentUpdatePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def patch(self, agent_id: UUID):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
return _agent_roster_service().update_roster_agent(
|
||||
tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id, payload=payload
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID):
|
||||
return dump_response(
|
||||
AgentRosterResponse,
|
||||
_agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id)),
|
||||
)
|
||||
|
||||
@console_ns.expect(console_ns.models[RosterAgentUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Agent updated", console_ns.models[AgentRosterResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, agent_id: UUID):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id)
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def patch(self, tenant_id: str, account_id: str, agent_id: UUID):
|
||||
payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
return dump_response(
|
||||
AgentRosterResponse,
|
||||
_agent_roster_service().update_roster_agent(
|
||||
tenant_id=tenant_id, agent_id=str(agent_id), account_id=account_id, payload=payload
|
||||
),
|
||||
)
|
||||
|
||||
@console_ns.response(204, "Agent archived")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def delete(self, tenant_id: str, account_id: str, agent_id: UUID):
|
||||
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account_id)
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/agents/<uuid:agent_id>/versions")
|
||||
class AgentRosterVersionsApi(Resource):
|
||||
@console_ns.response(200, "Agent versions", console_ns.models[AgentConfigSnapshotListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id: UUID):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return {"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))}
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID):
|
||||
return dump_response(
|
||||
AgentConfigSnapshotListResponse,
|
||||
{"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/agents/<uuid:agent_id>/versions/<uuid:version_id>")
|
||||
class AgentRosterVersionDetailApi(Resource):
|
||||
@console_ns.response(200, "Agent version detail", console_ns.models[AgentConfigSnapshotDetailResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id: UUID, version_id: UUID):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return _agent_roster_service().get_agent_version_detail(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=str(agent_id),
|
||||
version_id=str(version_id),
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID, version_id: UUID):
|
||||
return dump_response(
|
||||
AgentConfigSnapshotDetailResponse,
|
||||
_agent_roster_service().get_agent_version_detail(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=str(agent_id),
|
||||
version_id=str(version_id),
|
||||
),
|
||||
)
|
||||
|
||||
59
api/controllers/console/app/agent_app_access.py
Normal file
59
api/controllers/console/app/agent_app_access.py
Normal file
@ -0,0 +1,59 @@
|
||||
"""Agent App access & sharing endpoints (read-only workflow references).
|
||||
|
||||
An Agent App is backed by a roster Agent that workflow Agent nodes may also
|
||||
reference. This exposes the read-only "Workflow access" surface from the PRD:
|
||||
which workflow apps use this Agent, without leaking the workflows' internals.
|
||||
"""
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import Field
|
||||
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppMode
|
||||
from services.agent.roster_service import AgentRosterService
|
||||
|
||||
|
||||
class AgentReferencingWorkflowResponse(ResponseModel):
|
||||
app_id: str
|
||||
app_name: str
|
||||
app_mode: str
|
||||
workflow_id: str
|
||||
node_ids: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentReferencingWorkflowsResponse(ResponseModel):
|
||||
data: list[AgentReferencingWorkflowResponse] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_response_schema_models(console_ns, AgentReferencingWorkflowsResponse)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-referencing-workflows")
|
||||
class AgentAppReferencingWorkflowsResource(Resource):
|
||||
@console_ns.doc("list_agent_app_referencing_workflows")
|
||||
@console_ns.doc(description="List workflow apps that reference this Agent App's bound Agent (read-only)")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Referencing workflows listed successfully",
|
||||
console_ns.models[AgentReferencingWorkflowsResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "App not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App):
|
||||
workflows = AgentRosterService(db.session).list_workflows_referencing_app_agent(
|
||||
tenant_id=tenant_id, app_id=app_model.id
|
||||
)
|
||||
return AgentReferencingWorkflowsResponse(
|
||||
data=[AgentReferencingWorkflowResponse.model_validate(workflow) for workflow in workflows]
|
||||
).model_dump(mode="json")
|
||||
93
api/controllers/console/app/agent_app_feature.py
Normal file
93
api/controllers/console/app/agent_app_feature.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""Agent App presentation-feature configuration endpoint.
|
||||
|
||||
The new Agent App type keeps model / prompt / tools in its bound Agent Soul, so
|
||||
the legacy ``/model-config`` surface (which writes model, prompt and agent tool
|
||||
config) is the wrong place to configure its app-level presentation features.
|
||||
This endpoint exposes only the PRD "Misc Legacy" feature subset — conversation
|
||||
opener, follow-up suggestions, citations, content moderation and speech — and
|
||||
persists them onto the app's ``app_model_config`` without touching anything the
|
||||
Soul owns.
|
||||
"""
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.agent_config_entities import (
|
||||
AgentFeatureToggleConfig,
|
||||
AgentSensitiveWordAvoidanceFeatureConfig,
|
||||
AgentSuggestedQuestionsAfterAnswerFeatureConfig,
|
||||
AgentTextToSpeechFeatureConfig,
|
||||
)
|
||||
from models.model import App, AppMode
|
||||
from services.agent_app_feature_service import AgentAppFeatureConfigService
|
||||
|
||||
|
||||
class AgentAppFeaturesPayload(BaseModel):
|
||||
"""Presentation features configurable on an Agent App.
|
||||
|
||||
All fields are optional; an omitted field is reset to its disabled/empty
|
||||
default (the config form sends the full desired feature state on save).
|
||||
"""
|
||||
|
||||
opening_statement: str | None = Field(default=None, description="Conversation opener shown before the first turn")
|
||||
suggested_questions: list[str] | None = Field(
|
||||
default=None, description="Preset questions shown alongside the opener"
|
||||
)
|
||||
suggested_questions_after_answer: AgentSuggestedQuestionsAfterAnswerFeatureConfig | None = Field(
|
||||
default=None, description="Follow-up suggestions config, e.g. {'enabled': true}"
|
||||
)
|
||||
speech_to_text: AgentFeatureToggleConfig | None = Field(default=None, description="Speech-to-text config")
|
||||
text_to_speech: AgentTextToSpeechFeatureConfig | None = Field(default=None, description="Text-to-speech config")
|
||||
retriever_resource: AgentFeatureToggleConfig | None = Field(
|
||||
default=None, description="Citations / attributions config, e.g. {'enabled': true}"
|
||||
)
|
||||
sensitive_word_avoidance: AgentSensitiveWordAvoidanceFeatureConfig | None = Field(
|
||||
default=None, description="Content moderation config"
|
||||
)
|
||||
|
||||
|
||||
register_schema_models(console_ns, AgentAppFeaturesPayload)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-features")
|
||||
class AgentAppFeatureConfigResource(Resource):
|
||||
@console_ns.doc("update_agent_app_features")
|
||||
@console_ns.doc(description="Update an Agent App's presentation features (opener, follow-up, citations, ...)")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AgentAppFeaturesPayload.__name__])
|
||||
@console_ns.response(200, "Features updated successfully", console_ns.models[SimpleResultResponse.__name__])
|
||||
@console_ns.response(400, "Invalid configuration")
|
||||
@console_ns.response(404, "App not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
args = AgentAppFeaturesPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
new_app_model_config = AgentAppFeatureConfigService.update_features(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
config=args.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
|
||||
|
||||
return dump_response(SimpleResultResponse, {"result": "success"})
|
||||
@ -25,6 +25,9 @@ from controllers.console.wraps import (
|
||||
enterprise_license_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
with_current_user_id,
|
||||
)
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||
@ -34,8 +37,8 @@ from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from libs.helper import build_icon_url, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, DatasetPermissionEnum, Workflow
|
||||
from libs.login import login_required
|
||||
from models import Account, App, DatasetPermissionEnum, Workflow
|
||||
from models.model import IconType
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_service import AppListParams, AppService, CreateAppParams
|
||||
@ -55,7 +58,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
)
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "agent", "advanced-chat", "workflow", "completion"]
|
||||
|
||||
register_enum_models(console_ns, IconType)
|
||||
|
||||
@ -66,7 +69,7 @@ _TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
|
||||
class AppListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
|
||||
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"] = Field(
|
||||
default="all", description="App mode filter"
|
||||
)
|
||||
name: str | None = Field(default=None, description="Filter by app name")
|
||||
@ -115,7 +118,9 @@ def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str,
|
||||
class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
|
||||
mode: Literal["chat", "agent-chat", "agent", "advanced-chat", "workflow", "completion"] = Field(
|
||||
..., description="App mode"
|
||||
)
|
||||
icon_type: IconType | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
@ -393,6 +398,8 @@ class AppDetailWithSite(AppDetail):
|
||||
max_active_requests: int | None = None
|
||||
deleted_tools: list[DeletedTool] = Field(default_factory=list)
|
||||
site: Site | None = None
|
||||
# For Agent App type: the roster Agent backing this app (None otherwise).
|
||||
bound_agent_id: str | None = None
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
@ -467,10 +474,11 @@ class AppListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
@with_session(write=False)
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user_id: str, session: Session):
|
||||
"""Get app list"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args))
|
||||
params = AppListParams(
|
||||
page=args.page,
|
||||
@ -483,7 +491,7 @@ class AppListApi(Resource):
|
||||
|
||||
# get app list
|
||||
app_service = AppService()
|
||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, params)
|
||||
app_pagination = app_service.get_paginate_apps(current_user_id, current_tenant_id, params)
|
||||
if not app_pagination:
|
||||
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
|
||||
return empty.model_dump(mode="json"), 200
|
||||
@ -504,7 +512,7 @@ class AppListApi(Resource):
|
||||
draft_trigger_app_ids: set[str] = set()
|
||||
if workflow_capable_app_ids:
|
||||
draft_workflows = (
|
||||
db.session.execute(
|
||||
session.execute(
|
||||
select(Workflow).where(
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
Workflow.app_id.in_(workflow_capable_app_ids),
|
||||
@ -543,9 +551,10 @@ class AppListApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
"""Create app"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
args = CreateAppPayload.model_validate(console_ns.payload)
|
||||
params = CreateAppParams(
|
||||
name=args.name,
|
||||
@ -648,11 +657,10 @@ class AppCopyApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""Copy app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = CopyAppPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
@ -731,7 +739,8 @@ class AppPublishToCreatorsPlatformApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, app_model: App):
|
||||
"""Publish app to Creators Platform"""
|
||||
from configs import dify_config
|
||||
from core.helper.creators import get_redirect_url, upload_dsl
|
||||
@ -739,13 +748,11 @@ class AppPublishToCreatorsPlatformApi(Resource):
|
||||
if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
|
||||
return {"error": "Creators Platform features are not enabled"}, 403
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
|
||||
dsl_bytes = dsl_content.encode("utf-8")
|
||||
|
||||
claim_code = upload_dsl(dsl_bytes)
|
||||
redirect_url = get_redirect_url(str(current_user.id), claim_code)
|
||||
redirect_url = get_redirect_url(current_user_id, claim_code)
|
||||
|
||||
return {"redirect_url": redirect_url}
|
||||
|
||||
|
||||
@ -9,9 +9,11 @@ from controllers.console.wraps import (
|
||||
cloud_edition_billing_resource_check,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.model import App
|
||||
from services.app_dsl_service import AppDslService, Import
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@ -48,9 +50,9 @@ class AppImportApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = AppImportPayload.model_validate(console_ns.payload)
|
||||
|
||||
# AppDslService performs internal commits for some creation paths, so use a plain
|
||||
@ -97,10 +99,9 @@ class AppImportConfirmApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, import_id: str):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, import_id: str):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Confirm import
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import Any, Literal
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
@ -19,7 +19,12 @@ from controllers.console.app.error import (
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user_id,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
@ -41,9 +46,24 @@ from services.errors.llm import InvokeRateLimitError
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_debugger_chat_streaming(
|
||||
*, app_mode: AppMode, response_mode: str, response_mode_provided: bool = True
|
||||
) -> bool:
|
||||
"""Agent App runtime is SSE-only until backend blocking runs are supported."""
|
||||
if app_mode != AppMode.AGENT:
|
||||
return response_mode != "blocking"
|
||||
if response_mode_provided and response_mode == "blocking":
|
||||
raise BadRequest("Agent App only supports streaming response mode.")
|
||||
return True
|
||||
|
||||
|
||||
class BaseMessagePayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config")
|
||||
# Agent Apps (AppMode.AGENT) derive their model + prompt from the bound Agent
|
||||
# Soul, so no override ``model_config`` is sent; chat / agent-chat / completion
|
||||
# debugging still pass it. Optional here, required in practice by those modes
|
||||
# downstream when their config is built from args.
|
||||
model_config_data: dict[str, Any] = Field(default_factory=dict, alias="model_config")
|
||||
files: list[Any] | None = Field(default=None, description="Uploaded files")
|
||||
response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
|
||||
retriever_from: str = Field(default="dev", description="Retriever source")
|
||||
@ -131,14 +151,13 @@ class CompletionMessageStopApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model: App, task_id: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, app_model: App, task_id: str):
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
@ -157,13 +176,20 @@ class ChatMessageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
args_model = ChatMessagePayload.model_validate(console_ns.payload)
|
||||
raw_payload = console_ns.payload or {}
|
||||
args_model = ChatMessagePayload.model_validate(raw_payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
streaming = args_model.response_mode != "blocking"
|
||||
streaming = _resolve_debugger_chat_streaming(
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
response_mode=args_model.response_mode,
|
||||
response_mode_provided=isinstance(raw_payload, dict) and "response_mode" in raw_payload,
|
||||
)
|
||||
if AppMode.value_of(app_model.mode) == AppMode.AGENT:
|
||||
args["response_mode"] = "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
@ -211,15 +237,14 @@ class ChatMessageStopApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model: App, task_id: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, app_model: App, task_id: str):
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
|
||||
@ -12,7 +12,12 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import (
|
||||
@ -31,8 +36,9 @@ from fields.conversation_fields import (
|
||||
ConversationWithSummaryPagination as ConversationWithSummaryPaginationResponse,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||
from models.account import Account
|
||||
from models.model import App, AppMode
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
@ -93,8 +99,8 @@ class CompletionConversationApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
query = sa.select(Conversation).where(
|
||||
@ -165,10 +171,11 @@ class CompletionConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App, conversation_id: UUID):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
return ConversationMessageDetailResponse.model_validate(
|
||||
_get_conversation(app_model, conversation_id_str), from_attributes=True
|
||||
_get_conversation(current_user, app_model, conversation_id_str), from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("delete_completion_conversation")
|
||||
@ -182,8 +189,8 @@ class CompletionConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, conversation_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, app_model: App, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
|
||||
try:
|
||||
@ -205,10 +212,10 @@ class ChatConversationApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
subquery = (
|
||||
@ -316,12 +323,13 @@ class ChatConversationDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App, conversation_id: UUID):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
return ConversationDetailResponse.model_validate(
|
||||
_get_conversation(app_model, conversation_id_str), from_attributes=True
|
||||
_get_conversation(current_user, app_model, conversation_id_str), from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("delete_chat_conversation")
|
||||
@ -332,11 +340,11 @@ class ChatConversationDetailApi(Resource):
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, conversation_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, app_model: App, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
|
||||
try:
|
||||
@ -347,8 +355,7 @@ class ChatConversationDetailApi(Resource):
|
||||
return "", 204
|
||||
|
||||
|
||||
def _get_conversation(app_model, conversation_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
def _get_conversation(current_user: Account, app_model, conversation_id):
|
||||
conversation = db.session.scalar(
|
||||
sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1)
|
||||
)
|
||||
|
||||
@ -2,6 +2,7 @@ from collections.abc import Sequence
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_enum_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
@ -11,6 +12,7 @@ from controllers.console.app.error import (
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.app.wraps import with_session
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
@ -19,7 +21,6 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import login_required
|
||||
@ -158,7 +159,8 @@ class InstructionGenerateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
@with_session(write=False)
|
||||
def post(self, session: Session, current_tenant_id: str):
|
||||
args = InstructionGeneratePayload.model_validate(console_ns.payload)
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] | None = next(
|
||||
@ -168,10 +170,10 @@ class InstructionGenerateApi(Resource):
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args.current in (code_template, "")) and args.node_id != "":
|
||||
app = db.session.get(App, args.flow_id)
|
||||
app = session.get(App, args.flow_id)
|
||||
if not app:
|
||||
return {"error": f"app {args.flow_id} not found"}, 400
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app, session=session)
|
||||
if not workflow:
|
||||
return {"error": f"workflow {args.flow_id} not found"}, 400
|
||||
nodes: Sequence = workflow.graph_dict["nodes"]
|
||||
|
||||
@ -25,6 +25,7 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
|
||||
@ -43,7 +44,8 @@ from fields.conversation_fields import (
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import to_timestamp, uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import App, AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
@ -178,7 +180,7 @@ class ChatMessageListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict())
|
||||
@ -257,9 +259,8 @@ class MessageFeedbackApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
args = MessageFeedbackPayload.model_validate(console_ns.payload)
|
||||
|
||||
message_id = str(args.message_id)
|
||||
@ -336,9 +337,9 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model: App, message_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App, message_id: UUID):
|
||||
message_id_str = str(message_id)
|
||||
|
||||
try:
|
||||
|
||||
@ -8,14 +8,20 @@ from pydantic import BaseModel, Field
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user_id,
|
||||
)
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
@ -52,9 +58,10 @@ class ModelConfigResource(Resource):
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
||||
def post(self, app_model: App):
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user_id: str, app_model: App):
|
||||
"""Modify app model config"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# validate config
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_tenant_id,
|
||||
@ -64,8 +71,8 @@ class ModelConfigResource(Resource):
|
||||
|
||||
new_app_model_config = AppModelConfig(
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
updated_by=current_user.id,
|
||||
created_by=current_user_id,
|
||||
updated_by=current_user_id,
|
||||
)
|
||||
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
|
||||
|
||||
@ -90,7 +97,7 @@ class ModelConfigResource(Resource):
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
)
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_tenant_id,
|
||||
@ -130,7 +137,7 @@ class ModelConfigResource(Resource):
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
@ -167,7 +174,7 @@ class ModelConfigResource(Resource):
|
||||
db.session.flush()
|
||||
|
||||
app_model.app_model_config_id = new_app_model_config.id
|
||||
app_model.updated_by = current_user.id
|
||||
app_model.updated_by = current_user_id
|
||||
app_model.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@ -14,12 +14,14 @@ from controllers.console.wraps import (
|
||||
edit_permission_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Site
|
||||
from models.account import Account
|
||||
from models.model import App
|
||||
|
||||
|
||||
@ -85,9 +87,9 @@ class AppSite(Resource):
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def post(self, app_model: App):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
if not site:
|
||||
raise NotFound
|
||||
@ -134,8 +136,8 @@ class AppSiteAccessTokenReset(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
|
||||
if not site:
|
||||
|
||||
@ -8,13 +8,14 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import convert_datetime_to_date
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import AppMode
|
||||
from models.account import Account
|
||||
from models.model import App
|
||||
|
||||
|
||||
@ -48,9 +49,8 @@ class DailyMessageStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
@ -109,9 +109,8 @@ class DailyConversationStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
@ -169,9 +168,8 @@ class DailyTerminalsStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
@ -230,9 +228,8 @@ class DailyTokenCostStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
@ -293,10 +290,9 @@ class AverageSessionInteractionStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("c.created_at")
|
||||
@ -374,9 +370,8 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("m.created_at")
|
||||
@ -444,9 +439,8 @@ class AverageResponseTimeStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
@ -505,8 +499,8 @@ class TokensPerSecondStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
|
||||
@ -6,10 +6,11 @@ from sqlalchemy.orm import sessionmaker
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import App, AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
@ -46,9 +47,8 @@ class WorkflowDailyRunsStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
@ -86,9 +86,8 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
@ -126,9 +125,8 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
@ -166,9 +164,8 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
@ -32,11 +32,11 @@ from controllers.console.wraps import (
|
||||
decrypt_password_field,
|
||||
email_password_login_enabled,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from events.tenant_event import tenant_was_created
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.helper import timezone as validate_timezone_string
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.token import (
|
||||
clear_access_token_from_cookie,
|
||||
clear_csrf_token_from_cookie,
|
||||
@ -46,6 +46,7 @@ from libs.token import (
|
||||
set_csrf_token_to_cookie,
|
||||
set_refresh_token_to_cookie,
|
||||
)
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
|
||||
@ -172,9 +173,8 @@ class LoginApi(Resource):
|
||||
class LogoutApi(Resource):
|
||||
@setup_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
@with_current_user
|
||||
def post(self, account: Account):
|
||||
if isinstance(account, flask_login.AnonymousUserMixin):
|
||||
response = make_response({"result": "success"})
|
||||
else:
|
||||
|
||||
@ -8,9 +8,16 @@ from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
@ -32,8 +39,9 @@ class Subscription(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True))
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
|
||||
@ -45,8 +53,9 @@ class Invoices(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
return BillingService.get_invoices(current_user.email, current_tenant_id)
|
||||
|
||||
@ -63,9 +72,8 @@ class PartnerTenants(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def put(self, partner_key: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def put(self, current_user: Account, partner_key: str):
|
||||
try:
|
||||
args = PartnerTenantsPayload.model_validate(console_ns.payload or {})
|
||||
click_id = args.click_id
|
||||
|
||||
@ -3,11 +3,18 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.billing_service import BillingService
|
||||
|
||||
from .. import console_ns
|
||||
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from ..wraps import (
|
||||
account_initialization_required,
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
|
||||
|
||||
class ComplianceDownloadQuery(BaseModel):
|
||||
@ -29,8 +36,9 @@ class ComplianceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
|
||||
@ -9,7 +9,7 @@ from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request, send_file
|
||||
from flask_restx import Resource, marshal
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import asc, desc, func, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
@ -34,14 +34,16 @@ from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from fields.document_fields import (
|
||||
document_fields,
|
||||
document_status_fields,
|
||||
document_with_segments_fields,
|
||||
DocumentMetadataResponse,
|
||||
DocumentResponse,
|
||||
DocumentStatusListResponse,
|
||||
DocumentStatusResponse,
|
||||
normalize_enum,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import to_timestamp
|
||||
from libs.helper import dump_response, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
@ -74,12 +76,6 @@ from ..wraps import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_enum(value: Any) -> Any:
|
||||
if isinstance(value, str) or value is None:
|
||||
return value
|
||||
return getattr(value, "value", value)
|
||||
|
||||
|
||||
class DatasetResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
@ -93,7 +89,7 @@ class DatasetResponse(ResponseModel):
|
||||
@field_validator("data_source_type", "indexing_technique", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum_fields(cls, value: Any) -> Any:
|
||||
return _normalize_enum(value)
|
||||
return normalize_enum(value)
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
@ -101,61 +97,10 @@ class DatasetResponse(ResponseModel):
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class DocumentMetadataResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
value: str | None = None
|
||||
|
||||
|
||||
class DocumentResponse(ResponseModel):
|
||||
id: str
|
||||
position: int | None = None
|
||||
data_source_type: str | None = None
|
||||
data_source_info: Any = Field(default=None, validation_alias="data_source_info_dict")
|
||||
data_source_detail_dict: Any = None
|
||||
dataset_process_rule_id: str | None = None
|
||||
name: str
|
||||
created_from: str | None = None
|
||||
created_by: str | None = None
|
||||
created_at: int | None = None
|
||||
tokens: int | None = None
|
||||
indexing_status: str | None = None
|
||||
error: str | None = None
|
||||
enabled: bool | None = None
|
||||
disabled_at: int | None = None
|
||||
disabled_by: str | None = None
|
||||
archived: bool | None = None
|
||||
display_status: str | None = None
|
||||
word_count: int | None = None
|
||||
hit_count: int | None = None
|
||||
doc_form: str | None = None
|
||||
doc_metadata: list[DocumentMetadataResponse] = Field(default_factory=list, validation_alias="doc_metadata_details")
|
||||
summary_index_status: str | None = None
|
||||
need_summary: bool | None = None
|
||||
|
||||
@field_validator("data_source_type", "indexing_status", "display_status", "doc_form", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum_fields(cls, value: Any) -> Any:
|
||||
return _normalize_enum(value)
|
||||
|
||||
@field_validator("doc_metadata", mode="before")
|
||||
@classmethod
|
||||
def _normalize_doc_metadata(cls, value: Any) -> list[Any]:
|
||||
if value is None:
|
||||
return []
|
||||
return value
|
||||
|
||||
@field_validator("created_at", "disabled_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class DocumentWithSegmentsResponse(DocumentResponse):
|
||||
process_rule_dict: Any = None
|
||||
completed_segments: int | None = None
|
||||
total_segments: int | None = None
|
||||
completed_segments: int | None = Field(default=None, exclude_if=lambda value: value is None)
|
||||
total_segments: int | None = Field(default=None, exclude_if=lambda value: value is None)
|
||||
|
||||
|
||||
class DatasetAndDocumentResponse(ResponseModel):
|
||||
@ -190,6 +135,14 @@ class DocumentDatasetListParam(BaseModel):
|
||||
fetch_val: str = Field("false", alias="fetch")
|
||||
|
||||
|
||||
class DocumentWithSegmentsListResponse(ResponseModel):
|
||||
data: list[DocumentWithSegmentsResponse]
|
||||
has_more: bool
|
||||
limit: int
|
||||
total: int
|
||||
page: int
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
KnowledgeConfig,
|
||||
@ -200,13 +153,19 @@ register_schema_models(
|
||||
GenerateSummaryPayload,
|
||||
DocumentMetadataUpdatePayload,
|
||||
DocumentBatchDownloadZipPayload,
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
SimpleResultMessageResponse,
|
||||
SimpleResultResponse,
|
||||
UrlResponse,
|
||||
DatasetResponse,
|
||||
DocumentMetadataResponse,
|
||||
DocumentResponse,
|
||||
DocumentWithSegmentsResponse,
|
||||
DatasetAndDocumentResponse,
|
||||
DocumentWithSegmentsListResponse,
|
||||
)
|
||||
register_response_schema_models(console_ns, SimpleResultMessageResponse, SimpleResultResponse, UrlResponse)
|
||||
|
||||
|
||||
class DocumentResource(Resource):
|
||||
@ -312,7 +271,11 @@ class DatasetDocumentListApi(Resource):
|
||||
"status": "Filter documents by display status",
|
||||
}
|
||||
)
|
||||
@console_ns.response(200, "Documents retrieved successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Documents retrieved successfully",
|
||||
console_ns.models[DocumentWithSegmentsListResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -425,18 +388,15 @@ class DatasetDocumentListApi(Resource):
|
||||
)
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
data = marshal(documents, document_with_segments_fields)
|
||||
else:
|
||||
data = marshal(documents, document_fields)
|
||||
response = {
|
||||
"data": data,
|
||||
"data": documents,
|
||||
"has_more": len(documents) == limit,
|
||||
"limit": limit,
|
||||
"total": paginated_documents.total,
|
||||
"page": page,
|
||||
}
|
||||
|
||||
return response
|
||||
return dump_response(DocumentWithSegmentsListResponse, response)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -482,9 +442,7 @@ class DatasetDocumentListApi(Resource):
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
|
||||
return DatasetAndDocumentResponse.model_validate(
|
||||
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
return dump_response(DatasetAndDocumentResponse, {"dataset": dataset, "documents": documents, "batch": batch})
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -567,9 +525,7 @@ class DatasetInitApi(Resource):
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
|
||||
return DatasetAndDocumentResponse.model_validate(
|
||||
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
return dump_response(DatasetAndDocumentResponse, {"dataset": dataset, "documents": documents, "batch": batch})
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate")
|
||||
@ -742,6 +698,9 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
|
||||
class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
@console_ns.response(
|
||||
200, "Indexing status retrieved successfully", console_ns.models[DocumentStatusListResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -784,9 +743,8 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
"completed_segments": completed_segments,
|
||||
"total_segments": total_segments,
|
||||
}
|
||||
documents_status.append(marshal(document_dict, document_status_fields))
|
||||
data = {"data": documents_status}
|
||||
return data
|
||||
documents_status.append(document_dict)
|
||||
return dump_response(DocumentStatusListResponse, {"data": documents_status})
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
|
||||
@ -794,7 +752,9 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
@console_ns.doc("get_document_indexing_status")
|
||||
@console_ns.doc(description="Get document indexing status")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@console_ns.response(200, "Indexing status retrieved successfully")
|
||||
@console_ns.response(
|
||||
200, "Indexing status retrieved successfully", console_ns.models[DocumentStatusResponse.__name__]
|
||||
)
|
||||
@console_ns.response(404, "Document not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -839,7 +799,7 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
"completed_segments": completed_segments,
|
||||
"total_segments": total_segments,
|
||||
}
|
||||
return marshal(document_dict, document_status_fields)
|
||||
return dump_response(DocumentStatusResponse, document_dict)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
@ -1304,7 +1264,7 @@ class DocumentRenameApi(DocumentResource):
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
return DocumentResponse.model_validate(document, from_attributes=True).model_dump(mode="json")
|
||||
return dump_response(DocumentResponse, document)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
import uuid
|
||||
from typing import Literal
|
||||
from typing import cast as type_cast
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import String, cast, func, or_, select
|
||||
from sqlalchemy import String, case, cast, func, literal, or_, select
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
@ -13,7 +14,12 @@ import services
|
||||
from configs import dify_config
|
||||
from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.common.schema import (
|
||||
query_params_from_model,
|
||||
query_params_from_request,
|
||||
register_response_schema_models,
|
||||
register_schema_models,
|
||||
)
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import (
|
||||
@ -34,9 +40,17 @@ from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.base import ResponseModel
|
||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||
from fields.segment_fields import (
|
||||
ChildChunkDetailResponse,
|
||||
ChildChunkListResponse,
|
||||
ChildChunkResponse,
|
||||
SegmentDetailResponse,
|
||||
SegmentResponse,
|
||||
segment_response_with_summary,
|
||||
segment_responses_with_summaries,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.helper import escape_like_pattern
|
||||
from libs.helper import dump_response, escape_like_pattern
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import ChildChunk, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
@ -44,20 +58,10 @@ from services.dataset_service import DatasetService, DocumentService, SegmentSer
|
||||
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
|
||||
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||
|
||||
|
||||
def _get_segment_with_summary(segment, dataset_id):
|
||||
"""Helper function to marshal segment and add summary information."""
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
|
||||
# Query summary for this segment (only enabled summaries)
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
|
||||
segment_dict["summary"] = summary.summary_content if summary else None
|
||||
return segment_dict
|
||||
|
||||
|
||||
class SegmentListQuery(BaseModel):
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
status: list[str] = Field(default_factory=list)
|
||||
@ -67,6 +71,16 @@ class SegmentListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1)
|
||||
|
||||
|
||||
class SegmentIdListQuery(BaseModel):
|
||||
segment_id: list[str] = Field(default_factory=list, description="Segment IDs")
|
||||
|
||||
|
||||
class ChildChunkListQuery(BaseModel):
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
keyword: str | None = None
|
||||
page: int = Field(default=1, ge=1)
|
||||
|
||||
|
||||
class SegmentCreatePayload(BaseModel):
|
||||
content: str
|
||||
answer: str | None = None
|
||||
@ -92,13 +106,35 @@ class SegmentBatchImportStatusResponse(ResponseModel):
|
||||
job_status: str
|
||||
|
||||
|
||||
class ConsoleSegmentListResponse(ResponseModel):
|
||||
data: list[SegmentResponse]
|
||||
limit: int
|
||||
total: int
|
||||
total_pages: int
|
||||
page: int
|
||||
|
||||
|
||||
class ChildChunkBatchUpdateResponse(ResponseModel):
|
||||
data: list[ChildChunkResponse]
|
||||
|
||||
|
||||
class ChildChunkBatchUpdatePayload(BaseModel):
|
||||
chunks: list[ChildChunkUpdateArgs]
|
||||
|
||||
|
||||
class SegmentDocParams:
|
||||
DATASET_DOCUMENT = {"dataset_id": "Dataset ID", "document_id": "Document ID"}
|
||||
DATASET_DOCUMENT_ACTION = {**DATASET_DOCUMENT, "action": "Action"}
|
||||
DATASET_DOCUMENT_SEGMENT = {**DATASET_DOCUMENT, "segment_id": "Segment ID"}
|
||||
DATASET_DOCUMENT_PARENT_SEGMENT = {**DATASET_DOCUMENT, "segment_id": "Parent segment ID"}
|
||||
DATASET_DOCUMENT_CHILD_CHUNK = {**DATASET_DOCUMENT_PARENT_SEGMENT, "child_chunk_id": "Child chunk ID"}
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
SegmentListQuery,
|
||||
SegmentIdListQuery,
|
||||
ChildChunkListQuery,
|
||||
SegmentCreatePayload,
|
||||
SegmentUpdatePayload,
|
||||
BatchImportPayload,
|
||||
@ -107,11 +143,24 @@ register_schema_models(
|
||||
ChildChunkBatchUpdatePayload,
|
||||
ChildChunkUpdateArgs,
|
||||
)
|
||||
register_response_schema_models(console_ns, SegmentBatchImportStatusResponse, SimpleResultResponse)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
SegmentResponse,
|
||||
ConsoleSegmentListResponse,
|
||||
SegmentDetailResponse,
|
||||
ChildChunkDetailResponse,
|
||||
ChildChunkListResponse,
|
||||
ChildChunkBatchUpdateResponse,
|
||||
SegmentBatchImportStatusResponse,
|
||||
SimpleResultResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
||||
class DatasetDocumentSegmentListApi(Resource):
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
|
||||
@console_ns.doc(params=query_params_from_model(SegmentListQuery))
|
||||
@console_ns.response(200, "Segments retrieved successfully", console_ns.models[ConsoleSegmentListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -134,12 +183,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
args = SegmentListQuery.model_validate(
|
||||
{
|
||||
**request.args.to_dict(),
|
||||
"status": request.args.getlist("status"),
|
||||
}
|
||||
)
|
||||
args = query_params_from_request(SegmentListQuery, list_fields=("status",))
|
||||
|
||||
page = args.page
|
||||
limit = min(args.limit, 100)
|
||||
@ -169,12 +213,17 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
# Use database-specific methods for JSON array search
|
||||
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
|
||||
# PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text
|
||||
# Guard with jsonb_typeof to avoid "cannot extract elements from a scalar" error
|
||||
# when keywords is null or a non-array JSON value.
|
||||
# Feed the set-returning function a JSON array in every row. Filtering in
|
||||
# the subquery is not enough because PostgreSQL can still evaluate the
|
||||
# SRF on scalar JSON before applying the predicate.
|
||||
keywords_jsonb = cast(DocumentSegment.keywords, JSONB)
|
||||
keywords_array = case(
|
||||
(func.jsonb_typeof(keywords_jsonb) == "array", keywords_jsonb),
|
||||
else_=cast(literal("[]"), JSONB),
|
||||
)
|
||||
keywords_condition = func.array_to_string(
|
||||
func.array(
|
||||
select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB)))
|
||||
.where(func.jsonb_typeof(cast(DocumentSegment.keywords, JSONB)) == "array")
|
||||
select(func.jsonb_array_elements_text(keywords_array))
|
||||
.correlate(DocumentSegment)
|
||||
.scalar_subquery()
|
||||
),
|
||||
@ -200,38 +249,30 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
|
||||
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
|
||||
# Query summaries for all segments in this page (batch query for efficiency)
|
||||
segment_ids = [segment.id for segment in segments.items]
|
||||
summaries = {}
|
||||
segment_list = list(segments.items)
|
||||
segment_ids = [segment.id for segment in segment_list]
|
||||
summaries: dict[str, str | None] = {}
|
||||
if segment_ids:
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
summary_records = SummaryIndexService.get_segments_summaries(
|
||||
segment_ids=segment_ids, dataset_id=dataset_id_str
|
||||
)
|
||||
# Only include enabled summaries (already filtered by service)
|
||||
summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()}
|
||||
|
||||
# Add summary to each segment
|
||||
segments_with_summary = []
|
||||
for segment in segments.items:
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
|
||||
segment_dict["summary"] = summaries.get(segment.id)
|
||||
segments_with_summary.append(segment_dict)
|
||||
|
||||
response = {
|
||||
"data": segments_with_summary,
|
||||
"data": segment_responses_with_summaries(segment_list, summaries),
|
||||
"limit": limit,
|
||||
"total": segments.total,
|
||||
"total_pages": segments.pages,
|
||||
"page": page,
|
||||
}
|
||||
return response, 200
|
||||
return dump_response(ConsoleSegmentListResponse, response), 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
|
||||
@console_ns.doc(params=query_params_from_model(SegmentIdListQuery))
|
||||
@console_ns.response(204, "Segments deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@ -263,6 +304,8 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>")
|
||||
class DatasetDocumentSegmentApi(Resource):
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_ACTION)
|
||||
@console_ns.doc(params=query_params_from_model(SegmentIdListQuery))
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -316,11 +359,12 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
SegmentService.update_segments_status(segment_ids, action, dataset, document)
|
||||
except Exception as e:
|
||||
raise InvalidActionError(str(e))
|
||||
return {"result": "success"}, 200
|
||||
return dump_response(SimpleResultResponse, {"result": "success"}), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
|
||||
class DatasetDocumentSegmentAddApi(Resource):
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -328,6 +372,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
|
||||
@console_ns.response(200, "Segment created successfully", console_ns.models[SegmentDetailResponse.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@ -367,18 +412,25 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
payload = SegmentCreatePayload.model_validate(console_ns.payload or {})
|
||||
payload_dict = payload.model_dump(exclude_none=True)
|
||||
SegmentService.segment_create_args_validate(payload_dict, document)
|
||||
segment = SegmentService.create_segment(payload_dict, document, dataset)
|
||||
return {"data": _get_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
|
||||
segment = type_cast(DocumentSegment, SegmentService.create_segment(payload_dict, document, dataset))
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id_str)
|
||||
response = {
|
||||
"data": segment_response_with_summary(segment, summary.summary_content if summary else None),
|
||||
"doc_form": document.doc_form,
|
||||
}
|
||||
return dump_response(SegmentDetailResponse, response), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
|
||||
class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Segment updated successfully", console_ns.models[SegmentDetailResponse.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@ -435,12 +487,18 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
segment = SegmentService.update_segment(
|
||||
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
|
||||
)
|
||||
return {"data": _get_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id_str)
|
||||
response = {
|
||||
"data": segment_response_with_summary(segment, summary.summary_content if summary else None),
|
||||
"doc_form": document.doc_form,
|
||||
}
|
||||
return dump_response(SegmentDetailResponse, response), 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
|
||||
@console_ns.response(204, "Segment deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@ -518,11 +576,11 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
try:
|
||||
# async job
|
||||
job_id = str(uuid.uuid4())
|
||||
indexing_cache_key = f"segment_batch_import_{str(job_id)}"
|
||||
indexing_cache_key = f"segment_batch_import_{job_id}"
|
||||
# send batch add segments task
|
||||
redis_client.setnx(indexing_cache_key, "waiting")
|
||||
batch_create_segment_to_index_task.delay(
|
||||
str(job_id),
|
||||
job_id,
|
||||
upload_file_id,
|
||||
dataset_id_str,
|
||||
document_id_str,
|
||||
@ -531,7 +589,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
)
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 500
|
||||
return {"job_id": job_id, "job_status": "waiting"}, 200
|
||||
return dump_response(SegmentBatchImportStatusResponse, {"job_id": job_id, "job_status": "waiting"}), 200
|
||||
|
||||
@console_ns.response(200, "Batch import status", console_ns.models[SegmentBatchImportStatusResponse.__name__])
|
||||
@setup_required
|
||||
@ -546,11 +604,13 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
if cache_result is None:
|
||||
raise ValueError("The job does not exist.")
|
||||
|
||||
return {"job_id": job_id, "job_status": cache_result.decode()}, 200
|
||||
response = {"job_id": job_id, "job_status": cache_result.decode()}
|
||||
return dump_response(SegmentBatchImportStatusResponse, response), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks")
|
||||
class ChildChunkAddApi(Resource):
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -558,6 +618,7 @@ class ChildChunkAddApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
|
||||
@console_ns.response(200, "Child chunk created successfully", console_ns.models[ChildChunkDetailResponse.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@ -608,8 +669,11 @@ class ChildChunkAddApi(Resource):
|
||||
child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
return dump_response(ChildChunkDetailResponse, {"data": child_chunk}), 200
|
||||
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
|
||||
@console_ns.doc(params=query_params_from_model(ChildChunkListQuery))
|
||||
@console_ns.response(200, "Child chunks retrieved successfully", console_ns.models[ChildChunkListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -637,13 +701,7 @@ class ChildChunkAddApi(Resource):
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
args = SegmentListQuery.model_validate(
|
||||
{
|
||||
"limit": request.args.get("limit", default=20, type=int),
|
||||
"keyword": request.args.get("keyword"),
|
||||
"page": request.args.get("page", default=1, type=int),
|
||||
}
|
||||
)
|
||||
args = query_params_from_request(ChildChunkListQuery, use_defaults_for_malformed_ints=True)
|
||||
|
||||
page = args.page
|
||||
limit = min(args.limit, 100)
|
||||
@ -652,19 +710,27 @@ class ChildChunkAddApi(Resource):
|
||||
child_chunks = SegmentService.get_child_chunks(
|
||||
segment_id_str, document_id_str, dataset_id_str, page, limit, keyword
|
||||
)
|
||||
return {
|
||||
"data": marshal(child_chunks.items, child_chunk_fields),
|
||||
response = {
|
||||
"data": child_chunks.items,
|
||||
"total": child_chunks.total,
|
||||
"total_pages": child_chunks.pages,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
}, 200
|
||||
}
|
||||
return dump_response(ChildChunkListResponse, response), 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Child chunks updated successfully",
|
||||
console_ns.models[ChildChunkBatchUpdateResponse.__name__],
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ChildChunkBatchUpdatePayload.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@ -702,7 +768,7 @@ class ChildChunkAddApi(Resource):
|
||||
child_chunks = SegmentService.update_child_chunks(payload.chunks, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
|
||||
return dump_response(ChildChunkBatchUpdateResponse, {"data": child_chunks}), 200
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
@ -713,6 +779,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
|
||||
@console_ns.response(204, "Child chunk deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@ -743,7 +810,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
child_chunk = db.session.scalar(
|
||||
select(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id_str),
|
||||
ChildChunk.id == child_chunk_id_str,
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id_str,
|
||||
@ -770,7 +837,9 @@ class ChildChunkUpdateApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
|
||||
@console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Child chunk updated successfully", console_ns.models[ChildChunkDetailResponse.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@ -800,7 +869,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
child_chunk = db.session.scalar(
|
||||
select(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id_str),
|
||||
ChildChunk.id == child_chunk_id_str,
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id_str,
|
||||
@ -822,4 +891,4 @@ class ChildChunkUpdateApi(Resource):
|
||||
child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
return dump_response(ChildChunkDetailResponse, {"data": child_chunk}), 200
|
||||
|
||||
@ -7,14 +7,20 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.common.controller_schemas import MetadataUpdatePayload
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
enterprise_license_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from fields.dataset_fields import (
|
||||
DatasetMetadataBuiltInFieldsResponse,
|
||||
DatasetMetadataListResponse,
|
||||
DatasetMetadataResponse,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DocumentMetadataOperation,
|
||||
@ -43,8 +49,8 @@ class DatasetMetadataCreateApi(Resource):
|
||||
@enterprise_license_required
|
||||
@console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
|
||||
def post(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -80,8 +86,8 @@ class DatasetMetadataApi(Resource):
|
||||
@enterprise_license_required
|
||||
@console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
|
||||
def patch(self, dataset_id: UUID, metadata_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def patch(self, current_user: Account, dataset_id: UUID, metadata_id: UUID):
|
||||
payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
|
||||
name = payload.name
|
||||
|
||||
@ -100,8 +106,8 @@ class DatasetMetadataApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(204, "Metadata deleted successfully")
|
||||
def delete(self, dataset_id: UUID, metadata_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, dataset_id: UUID, metadata_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
metadata_id_str = str(metadata_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -137,8 +143,8 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(204, "Action completed successfully")
|
||||
def post(self, dataset_id: UUID, action: Literal["enable", "disable"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID, action: Literal["enable", "disable"]):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -165,8 +171,8 @@ class DocumentMetadataEditApi(Resource):
|
||||
204,
|
||||
"Documents metadata updated successfully",
|
||||
)
|
||||
def post(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
|
||||
@ -198,12 +198,13 @@ class DatasourceAuth(Resource):
|
||||
def get(self, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasources = datasource_provider_service.list_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
user=user,
|
||||
)
|
||||
return {"result": datasources}, 200
|
||||
|
||||
|
||||
@ -11,10 +11,13 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
|
||||
@ -35,9 +38,10 @@ class CreateRagPipelineDatasetApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {})
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
@ -79,10 +83,10 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
|
||||
@ -10,6 +10,7 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.rag_pipeline_fields import (
|
||||
@ -17,7 +18,8 @@ from fields.rag_pipeline_fields import (
|
||||
pipeline_import_check_dependencies_fields,
|
||||
pipeline_import_fields,
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.entities.dsl_entities import ImportStatus
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
@ -62,9 +64,9 @@ class RagPipelineImportApi(Resource):
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_model)
|
||||
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# Use a plain Session so that caught exceptions inside the service
|
||||
@ -105,9 +107,8 @@ class RagPipelineImportConfirmApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_model)
|
||||
def post(self, import_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, import_id: str):
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
account = current_user
|
||||
|
||||
@ -18,6 +18,7 @@ from controllers.console.app.error import (
|
||||
)
|
||||
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import with_current_user_id
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
@ -135,20 +136,18 @@ class CompletionApi(InstalledAppResource):
|
||||
)
|
||||
class CompletionStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app: InstalledApp, task_id: str):
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, installed_app: InstalledApp, task_id: str):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
@ -215,7 +214,8 @@ class ChatApi(InstalledAppResource):
|
||||
)
|
||||
class ChatStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app: InstalledApp, task_id: str):
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, installed_app: InstalledApp, task_id: str):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
@ -223,13 +223,10 @@ class ChatStopApi(InstalledAppResource):
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
app_mode=app_mode,
|
||||
)
|
||||
|
||||
|
||||
@ -12,14 +12,19 @@ from controllers.common.fields import SimpleMessageResponse, SimpleResultMessage
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, InstalledApp, RecommendedApp
|
||||
from libs.login import login_required
|
||||
from models import Account, App, InstalledApp, RecommendedApp
|
||||
from models.model import IconType
|
||||
from services.account_service import TenantService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@ -131,9 +136,10 @@ class InstalledAppsListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__])
|
||||
def get(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
query = InstalledAppsListQuery.model_validate(request.args.to_dict())
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if query.app_id:
|
||||
installed_apps = db.session.scalars(
|
||||
@ -212,7 +218,8 @@ class InstalledAppsListApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleMessageResponse.__name__])
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
recommended_app = db.session.scalar(
|
||||
@ -221,8 +228,6 @@ class InstalledAppsListApi(Resource):
|
||||
if recommended_app is None:
|
||||
raise NotFound("Recommended app not found")
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
app = db.session.get(App, payload.app_id)
|
||||
|
||||
if app is None:
|
||||
@ -262,8 +267,8 @@ class InstalledAppApi(InstalledAppResource):
|
||||
"""
|
||||
|
||||
@console_ns.response(204, "App uninstalled successfully")
|
||||
def delete(self, installed_app: InstalledApp):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, installed_app: InstalledApp):
|
||||
if installed_app.app_owner_tenant_id == current_tenant_id:
|
||||
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
||||
|
||||
|
||||
@ -2,13 +2,36 @@ from flask_restx import Resource
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import dump_response
|
||||
from libs.login import current_user, login_required
|
||||
from services.feature_service import FeatureModel, FeatureService, LimitationModel, SystemFeatureModel
|
||||
from services.feature_service import (
|
||||
FeatureModel,
|
||||
FeatureService,
|
||||
LimitationModel,
|
||||
SystemFeatureModel,
|
||||
)
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required, with_current_tenant_id
|
||||
|
||||
register_response_schema_models(console_ns, FeatureModel, LimitationModel, SystemFeatureModel)
|
||||
|
||||
class TrialModelsResponse(ResponseModel):
|
||||
trial_models: list[str]
|
||||
|
||||
|
||||
class AppDslVersionResponse(ResponseModel):
|
||||
app_dsl_version: str
|
||||
|
||||
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
AppDslVersionResponse,
|
||||
FeatureModel,
|
||||
LimitationModel,
|
||||
SystemFeatureModel,
|
||||
TrialModelsResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/features")
|
||||
@ -54,6 +77,43 @@ class FeatureVectorSpaceApi(Resource):
|
||||
return FeatureService.get_vector_space(current_tenant_id).model_dump()
|
||||
|
||||
|
||||
@console_ns.route("/trial-models")
|
||||
class TrialModelsApi(Resource):
|
||||
@console_ns.doc("get_trial_models")
|
||||
@console_ns.doc(description="Get hosted trial model provider configuration")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.models[TrialModelsResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""Get hosted trial model provider configuration for model-provider pages."""
|
||||
return dump_response(
|
||||
TrialModelsResponse,
|
||||
{"trial_models": FeatureService.get_trial_models()},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/app-dsl-version")
|
||||
class AppDslVersionApi(Resource):
|
||||
@console_ns.doc("get_app_dsl_version")
|
||||
@console_ns.doc(description="Get current app DSL version")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.models[AppDslVersionResponse.__name__],
|
||||
)
|
||||
def get(self):
|
||||
"""Get current app DSL version for workflow clipboard compatibility."""
|
||||
return dump_response(
|
||||
AppDslVersionResponse,
|
||||
{"app_dsl_version": FeatureService.get_app_dsl_version()},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/system-features")
|
||||
class SystemFeatureApi(Resource):
|
||||
@console_ns.doc("get_system_features")
|
||||
|
||||
@ -12,8 +12,15 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
model_validate,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
@ -22,8 +29,8 @@ from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from libs.login import login_required
|
||||
from models import Account, App
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRun
|
||||
@ -33,6 +40,8 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
register_schema_models(console_ns, HumanInputFormSubmitPayload)
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form) -> Response:
|
||||
payload = form.get_definition().model_dump()
|
||||
@ -45,9 +54,8 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
"""Console API for getting human input form definition."""
|
||||
|
||||
@staticmethod
|
||||
def _ensure_console_access(form: Form):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
def _ensure_console_access(form: Form, current_tenant_id: str) -> None:
|
||||
"""Ensure a console form token resolves only inside the current tenant."""
|
||||
if form.tenant_id != current_tenant_id:
|
||||
raise NotFoundError("App not found")
|
||||
|
||||
@ -59,7 +67,8 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, form_token: str):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, form_token: str):
|
||||
"""
|
||||
Get human input form definition by form token.
|
||||
|
||||
@ -70,13 +79,23 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
if form is None:
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
|
||||
self._ensure_console_access(form)
|
||||
self._ensure_console_access(form, current_tenant_id)
|
||||
|
||||
return _jsonify_form_definition(form)
|
||||
|
||||
@account_initialization_required
|
||||
@login_required
|
||||
def post(self, form_token: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
@model_validate(HumanInputFormSubmitPayload)
|
||||
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
def post(
|
||||
self,
|
||||
payload: HumanInputFormSubmitPayload,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
form_token: str,
|
||||
):
|
||||
"""
|
||||
Submit human input form by form token.
|
||||
|
||||
@ -90,15 +109,12 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
"action": "Approve"
|
||||
}
|
||||
"""
|
||||
payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
|
||||
self._ensure_console_access(form)
|
||||
self._ensure_console_access(form, current_tenant_id)
|
||||
self._ensure_console_recipient_type(form)
|
||||
recipient_type = form.recipient_type
|
||||
# The type checker is not smart enought to validate the following invariant.
|
||||
@ -122,7 +138,9 @@ class ConsoleWorkflowEventsApi(Resource):
|
||||
|
||||
@account_initialization_required
|
||||
@login_required
|
||||
def get(self, workflow_run_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account, workflow_run_id: str):
|
||||
"""
|
||||
Get workflow execution events stream after resume.
|
||||
|
||||
@ -130,8 +148,6 @@ class ConsoleWorkflowEventsApi(Resource):
|
||||
|
||||
Returns Server-Sent Events stream.
|
||||
"""
|
||||
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
session_maker = sessionmaker(db.engine)
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
|
||||
|
||||
@ -13,7 +13,7 @@ from controllers.common.errors import (
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import with_current_user
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file import remote_fetcher
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
|
||||
from graphon.file import helpers as file_helpers
|
||||
@ -36,9 +36,9 @@ class GetRemoteFileInfo(Resource):
|
||||
@login_required
|
||||
def get(self, url: str):
|
||||
decoded_url = helpers.decode_remote_url(url, request.query_string)
|
||||
resp = ssrf_proxy.head(decoded_url)
|
||||
resp = remote_fetcher.make_request("HEAD", decoded_url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
resp = ssrf_proxy.get(decoded_url, timeout=3)
|
||||
resp = remote_fetcher.make_request("GET", decoded_url, timeout=3)
|
||||
resp.raise_for_status()
|
||||
return RemoteFileInfo(
|
||||
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
|
||||
@ -58,9 +58,9 @@ class RemoteFileUpload(Resource):
|
||||
|
||||
# Try to fetch remote file metadata/content first
|
||||
try:
|
||||
resp = ssrf_proxy.head(url=url)
|
||||
resp = remote_fetcher.make_request("HEAD", url=url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
|
||||
resp = remote_fetcher.make_request("GET", url=url, timeout=3, follow_redirects=True)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
# Normalize into a user-friendly error message expected by tests
|
||||
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
|
||||
@ -74,7 +74,7 @@ class RemoteFileUpload(Resource):
|
||||
raise FileTooLargeError()
|
||||
|
||||
# Load content if needed
|
||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
content = resp.content if resp.request.method == "GET" else remote_fetcher.make_request("GET", url).content
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
|
||||
@ -18,7 +18,7 @@ from controllers.common.fields import (
|
||||
SimpleResultResponse,
|
||||
VerificationTokenResponse,
|
||||
)
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
@ -48,7 +48,7 @@ from fields.base import ResponseModel
|
||||
from fields.member_fields import Account as AccountResponse
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, extract_remote_ip, timezone, to_timestamp
|
||||
from libs.helper import EmailStr, dump_response, extract_remote_ip, timezone, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AccountIntegrate, InvitationCode
|
||||
from models.account import AccountStatus, InvitationCodeStatus
|
||||
@ -173,7 +173,6 @@ class CheckEmailUniquePayload(BaseModel):
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
AccountResponse,
|
||||
AccountInitPayload,
|
||||
AccountNamePayload,
|
||||
AccountAvatarPayload,
|
||||
@ -245,6 +244,7 @@ register_schema_models(
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
AccountResponse,
|
||||
AvatarUrlResponse,
|
||||
SimpleResultDataResponse,
|
||||
SimpleResultResponse,
|
||||
@ -329,9 +329,9 @@ class AccountNameApi(Resource):
|
||||
|
||||
@console_ns.route("/account/avatar")
|
||||
class AccountAvatarApi(Resource):
|
||||
@console_ns.expect(console_ns.models[AccountAvatarQuery.__name__])
|
||||
@console_ns.doc("get_account_avatar")
|
||||
@console_ns.doc(description="Get account avatar url")
|
||||
@console_ns.doc(params=query_params_from_model(AccountAvatarQuery))
|
||||
@console_ns.response(200, "Success", console_ns.models[AvatarUrlResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -342,7 +342,7 @@ class AccountAvatarApi(Resource):
|
||||
avatar = args.avatar
|
||||
|
||||
if avatar.startswith(("http://", "https://")):
|
||||
return {"avatar_url": avatar}
|
||||
return dump_response(AvatarUrlResponse, {"avatar_url": avatar})
|
||||
|
||||
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == avatar).limit(1))
|
||||
if upload_file is None:
|
||||
@ -355,7 +355,7 @@ class AccountAvatarApi(Resource):
|
||||
raise NotFound("Avatar file not found")
|
||||
|
||||
avatar_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
|
||||
return {"avatar_url": avatar_url}
|
||||
return dump_response(AvatarUrlResponse, {"avatar_url": avatar_url})
|
||||
|
||||
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
|
||||
@setup_required
|
||||
|
||||
@ -1,9 +1,15 @@
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.agent_service import AgentService
|
||||
|
||||
|
||||
@ -19,14 +25,10 @@ class AgentProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
return jsonable_encoder(AgentService.list_agent_providers(current_user.id, current_tenant_id))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/agent-provider/<path:provider_name>")
|
||||
@ -42,6 +44,7 @@ class AgentProviderApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, provider_name: str):
|
||||
return jsonable_encoder(AgentService.get_agent_provider(current_user.id, current_tenant_id, provider_name))
|
||||
|
||||
@ -4,11 +4,16 @@ from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import TenantAccountRole
|
||||
from libs.login import login_required
|
||||
from models import Account, TenantAccountRole
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
|
||||
|
||||
@ -29,8 +34,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, provider: str):
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_role):
|
||||
raise Forbidden()
|
||||
|
||||
@ -72,8 +78,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str, config_id: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, provider: str, config_id: str):
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_role):
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
@ -25,12 +25,13 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
is_allow_transfer_owner,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.member_fields import AccountWithRole, AccountWithRoleList
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account, TenantAccountJoin, TenantAccountRole
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
@ -136,8 +137,8 @@ class MemberListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account):
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
@ -154,7 +155,8 @@ class MemberInviteEmailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = MemberInvitePayload.model_validate(payload)
|
||||
|
||||
@ -163,7 +165,6 @@ class MemberInviteEmailApi(Resource):
|
||||
interface_language = args.language
|
||||
if not TenantAccountRole.is_non_owner_role(invitee_role):
|
||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||
current_user, _ = current_account_with_tenant()
|
||||
inviter = current_user
|
||||
if not inviter.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
@ -223,8 +224,8 @@ class MemberCancelInviteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, member_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, member_id: UUID):
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
member = db.session.get(Account, str(member_id))
|
||||
@ -256,14 +257,14 @@ class MemberUpdateRoleApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, member_id: UUID):
|
||||
@with_current_user
|
||||
def put(self, current_user: Account, member_id: UUID):
|
||||
payload = console_ns.payload or {}
|
||||
args = MemberRoleUpdatePayload.model_validate(payload)
|
||||
new_role = args.role
|
||||
|
||||
if not TenantAccountRole.is_valid_role(new_role):
|
||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not _is_role_enabled(new_role, current_user.current_tenant.id):
|
||||
@ -297,8 +298,8 @@ class DatasetOperatorMemberListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account):
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
||||
@ -317,13 +318,13 @@ class SendOwnerTransferEmailApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_allow_transfer_owner
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = OwnerTransferEmailPayload.model_validate(payload)
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# check if the current user is the owner of the workspace
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
@ -355,11 +356,11 @@ class OwnerTransferCheckApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_allow_transfer_owner
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = OwnerTransferCheckPayload.model_validate(payload)
|
||||
# check if the current user is the owner of the workspace
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||
@ -399,12 +400,12 @@ class OwnerTransfer(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_allow_transfer_owner
|
||||
def post(self, member_id: UUID):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, member_id: UUID):
|
||||
payload = console_ns.payload or {}
|
||||
args = OwnerTransferPayload.model_validate(payload)
|
||||
|
||||
# check if the current user is the owner of the workspace
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||
|
||||
@ -8,12 +8,19 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.billing_service import BillingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
@ -95,10 +102,8 @@ class ModelProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
payload = request.args.to_dict(flat=True)
|
||||
args = ParserModelList.model_validate(payload)
|
||||
|
||||
@ -114,9 +119,8 @@ class ModelProviderCredentialApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
# if credential_id is not provided, return current used credential
|
||||
payload = request.args.to_dict(flat=True)
|
||||
args = ParserCredentialId.model_validate(payload)
|
||||
@ -133,8 +137,8 @@ class ModelProviderCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialCreate.model_validate(payload)
|
||||
|
||||
@ -157,9 +161,8 @@ class ModelProviderCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def put(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def put(self, current_tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialUpdate.model_validate(payload)
|
||||
|
||||
@ -184,8 +187,8 @@ class ModelProviderCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialDelete.model_validate(payload)
|
||||
|
||||
@ -205,8 +208,8 @@ class ModelProviderCredentialSwitchApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialSwitch.model_validate(payload)
|
||||
|
||||
@ -225,8 +228,8 @@ class ModelProviderValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialValidate.model_validate(payload)
|
||||
|
||||
@ -280,11 +283,8 @@ class PreferredProviderTypeUpdateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserPreferredProviderType.model_validate(payload)
|
||||
|
||||
@ -301,10 +301,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, provider: str):
|
||||
if provider != "anthropic":
|
||||
raise ValueError(f"provider name {provider} is invalid")
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
data = BillingService.get_model_provider_payment_link(
|
||||
provider_name=provider,
|
||||
|
||||
@ -13,12 +13,14 @@ from controllers.console.wraps import (
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
@ -193,7 +195,7 @@ class ModelProviderModelApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider):
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
|
||||
|
||||
@ -269,8 +271,9 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
def get(self, tenant_id: str, user: Account, provider: str):
|
||||
args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -292,9 +295,13 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
)
|
||||
|
||||
if args.config_from == "predefined-model":
|
||||
# Only the predefined-model branch needs visibility filtering by user.
|
||||
# The account is injected once by the handler and only passed into the
|
||||
# service branch that needs user-scoped credential visibility.
|
||||
available_credentials = model_provider_service.get_provider_available_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
available_credentials = model_provider_service.get_provider_model_available_credentials(
|
||||
|
||||
@ -69,6 +69,7 @@ class BuiltinToolAddPayload(BaseModel):
|
||||
credentials: dict[str, Any]
|
||||
name: str | None = Field(default=None, max_length=30)
|
||||
type: CredentialType
|
||||
visibility: str | None = None
|
||||
|
||||
|
||||
class BuiltinToolUpdatePayload(BaseModel):
|
||||
@ -277,7 +278,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return jsonable_encoder(
|
||||
@ -293,7 +294,7 @@ class ToolBuiltinProviderInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
|
||||
@ -306,7 +307,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
def post(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
payload = BuiltinToolCredentialDeletePayload.model_validate(console_ns.payload or {})
|
||||
@ -324,7 +325,7 @@ class ToolBuiltinProviderAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
def post(self, provider: str):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
@ -338,6 +339,7 @@ class ToolBuiltinProviderAddApi(Resource):
|
||||
credentials=payload.credentials,
|
||||
name=payload.name,
|
||||
api_type=CredentialType.of(payload.type),
|
||||
visibility=payload.visibility,
|
||||
)
|
||||
|
||||
|
||||
@ -348,7 +350,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
def post(self, provider: str):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
user_id = user.id
|
||||
|
||||
@ -370,13 +372,20 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
def get(self, provider: str):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
# Optional list of credential IDs to include even if visibility would hide them
|
||||
# (used when a workflow/agent node still references another member's only_me credential).
|
||||
include_credential_ids = request.args.getlist("include_credential_ids") or [
|
||||
s for s in (request.args.get("include_credential_ids") or "").split(",") if s
|
||||
]
|
||||
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_builtin_tool_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider,
|
||||
user=user,
|
||||
include_credential_ids=include_credential_ids or None,
|
||||
)
|
||||
)
|
||||
|
||||
@ -384,7 +393,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/icon")
|
||||
class ToolBuiltinProviderIconApi(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider)
|
||||
icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
|
||||
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
|
||||
@ -784,7 +793,7 @@ class ToolPluginOAuthApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
tool_provider = ToolProviderID(provider)
|
||||
plugin_id = tool_provider.plugin_id
|
||||
provider_name = tool_provider.provider_name
|
||||
@ -822,7 +831,7 @@ class ToolPluginOAuthApi(Resource):
|
||||
@console_ns.route("/oauth/plugin/<path:provider>/tool/callback")
|
||||
class ToolOAuthCallback(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
context_id = request.cookies.get("context_id")
|
||||
if not context_id:
|
||||
raise Forbidden("context_id not found")
|
||||
@ -859,7 +868,7 @@ class ToolOAuthCallback(Resource):
|
||||
if not credentials:
|
||||
raise Exception("the plugin credentials failed")
|
||||
|
||||
# add credentials to database
|
||||
# add credentials to database — OAuth tokens default to only_me since they're personal
|
||||
BuiltinToolManageService.add_builtin_tool_provider(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
@ -867,6 +876,7 @@ class ToolOAuthCallback(Resource):
|
||||
credentials=dict(credentials),
|
||||
expires_at=expires_at,
|
||||
api_type=CredentialType.OAUTH2,
|
||||
visibility="only_me",
|
||||
)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
@ -878,7 +888,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {})
|
||||
return BuiltinToolManageService.set_default_provider(
|
||||
@ -910,7 +920,7 @@ class ToolOAuthCustomClient(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
|
||||
@ -919,7 +929,7 @@ class ToolOAuthCustomClient(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider):
|
||||
def delete(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
|
||||
@ -931,7 +941,7 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
|
||||
@ -945,13 +955,18 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
def get(self, provider: str):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
include_credential_ids = request.args.getlist("include_credential_ids") or [
|
||||
s for s in (request.args.get("include_credential_ids") or "").split(",") if s
|
||||
]
|
||||
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
user=user,
|
||||
include_credential_ids=include_credential_ids or None,
|
||||
)
|
||||
)
|
||||
|
||||
@ -1151,7 +1166,7 @@ class ToolMCPDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
def get(self, provider_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
@ -1180,7 +1195,7 @@ class ToolMCPUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
def get(self, provider_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
|
||||
@ -77,7 +77,7 @@ class TriggerProviderIconApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
@ -103,7 +103,7 @@ class TriggerProviderInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
"""Get info for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
@ -119,15 +119,18 @@ class TriggerSubscriptionListApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
"""List all trigger subscriptions for the current tenant's provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
TriggerProviderService.list_trigger_provider_subscriptions(
|
||||
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
user=user,
|
||||
)
|
||||
)
|
||||
except ValueError as e:
|
||||
@ -146,7 +149,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
def post(self, provider: str):
|
||||
"""Add a new subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
@ -175,7 +178,7 @@ class TriggerSubscriptionBuilderGetApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, subscription_builder_id):
|
||||
def get(self, provider: str, subscription_builder_id: str):
|
||||
"""Get a subscription instance for a trigger provider"""
|
||||
return jsonable_encoder(
|
||||
TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
|
||||
@ -191,7 +194,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
def post(self, provider: str, subscription_builder_id: str):
|
||||
"""Verify and update a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
@ -223,7 +226,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
def post(self, provider: str, subscription_builder_id: str):
|
||||
"""Update a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
@ -257,7 +260,7 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, subscription_builder_id):
|
||||
def get(self, provider: str, subscription_builder_id: str):
|
||||
"""Get the request logs for a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
@ -280,7 +283,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
def post(self, provider: str, subscription_builder_id: str):
|
||||
"""Build a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
@ -404,7 +407,7 @@ class TriggerOAuthAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
"""Initiate OAuth authorization flow for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
@ -486,7 +489,7 @@ class TriggerOAuthAuthorizeApi(Resource):
|
||||
@console_ns.route("/oauth/plugin/<path:provider>/trigger/callback")
|
||||
class TriggerOAuthCallbackApi(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
"""Handle OAuth callback for trigger provider"""
|
||||
context_id = request.cookies.get("context_id")
|
||||
if not context_id:
|
||||
@ -554,7 +557,7 @@ class TriggerOAuthClientManageApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
"""Get OAuth client configuration for a provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
@ -600,7 +603,7 @@ class TriggerOAuthClientManageApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
def post(self, provider: str):
|
||||
"""Configure custom OAuth client for a provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
@ -626,7 +629,7 @@ class TriggerOAuthClientManageApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider):
|
||||
def delete(self, provider: str):
|
||||
"""Remove custom OAuth client configuration"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
@ -654,7 +657,7 @@ class TriggerSubscriptionVerifyApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_id):
|
||||
def post(self, provider: str, subscription_id: str):
|
||||
"""Verify credentials for an existing subscription (edit mode only)"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
@ -29,7 +29,7 @@ from controllers.console.wraps import (
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import TimestampField, to_timestamp
|
||||
from libs.helper import TimestampField, dump_response, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
|
||||
from services.account_service import TenantService
|
||||
@ -56,6 +56,11 @@ class WorkspaceCustomConfigPayload(BaseModel):
|
||||
replace_webapp_logo: str | None = None
|
||||
|
||||
|
||||
class WorkspaceCustomConfigResponse(ResponseModel):
|
||||
remove_webapp_brand: bool | None = None
|
||||
replace_webapp_logo: str | None = None
|
||||
|
||||
|
||||
class WorkspaceInfoPayload(BaseModel):
|
||||
name: str
|
||||
|
||||
@ -69,7 +74,7 @@ class TenantInfoResponse(ResponseModel):
|
||||
role: str | None = None
|
||||
in_trial: bool | None = None
|
||||
trial_end_reason: str | None = None
|
||||
custom_config: dict | None = None
|
||||
custom_config: WorkspaceCustomConfigResponse | None = None
|
||||
trial_credits: int | None = None
|
||||
trial_credits_used: int | None = None
|
||||
next_credit_reset_date: int | None = None
|
||||
@ -101,9 +106,13 @@ register_schema_models(
|
||||
SwitchWorkspacePayload,
|
||||
WorkspaceCustomConfigPayload,
|
||||
WorkspaceInfoPayload,
|
||||
TenantInfoResponse,
|
||||
)
|
||||
register_response_schema_models(console_ns, WorkspacePermissionResponse)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
TenantInfoResponse,
|
||||
WorkspaceCustomConfigResponse,
|
||||
WorkspacePermissionResponse,
|
||||
)
|
||||
|
||||
provider_fields = {
|
||||
"provider_name": fields.String,
|
||||
@ -238,13 +247,7 @@ class TenantApi(Resource):
|
||||
else:
|
||||
raise Unauthorized("workspace is archived")
|
||||
|
||||
return (
|
||||
TenantInfoResponse.model_validate(
|
||||
WorkspaceService.get_tenant_info(tenant),
|
||||
from_attributes=True,
|
||||
).model_dump(mode="json"),
|
||||
200,
|
||||
)
|
||||
return dump_response(TenantInfoResponse, WorkspaceService.get_tenant_info(tenant)), 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/switch")
|
||||
|
||||
@ -7,7 +7,9 @@ from functools import wraps
|
||||
from typing import Concatenate
|
||||
|
||||
from flask import abort, request
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import UnprocessableEntity
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError
|
||||
@ -512,9 +514,79 @@ def with_current_tenant_id[T, **P, R](
|
||||
def with_current_user[T, **P, R](
|
||||
view: Callable[Concatenate[T, Account, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]:
|
||||
"""Inject the current authenticated Account into the handler as the first argument after self.
|
||||
|
||||
Usage::
|
||||
|
||||
class MyResource(Resource):
|
||||
@login_required
|
||||
@with_current_user
|
||||
def get(self, current_user: Account):
|
||||
...
|
||||
"""
|
||||
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
return view(self, current_user, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def with_current_user_id[T, **P, R](
|
||||
view: Callable[Concatenate[T, str, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]:
|
||||
"""Inject the current authenticated user's ID (as a string) into the handler.
|
||||
|
||||
Use this when the handler only needs the user ID and not the full Account object.
|
||||
|
||||
Usage::
|
||||
|
||||
class MyResource(Resource):
|
||||
@login_required
|
||||
@with_current_user_id
|
||||
def get(self, current_user_id: str):
|
||||
...
|
||||
"""
|
||||
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
return view(self, str(current_user.id), *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def model_validate[T, M: BaseModel, **P, R](
|
||||
model: type[M],
|
||||
) -> Callable[
|
||||
[Callable[Concatenate[T, M, P], R]],
|
||||
Callable[Concatenate[T, P], R],
|
||||
]:
|
||||
"""Validate request data and inject the model instance as the first arg after self.
|
||||
|
||||
Source is determined by HTTP method:
|
||||
GET/DELETE -> request.args
|
||||
POST/PUT/PATCH -> JSON body
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
view: Callable[Concatenate[T, M, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]:
|
||||
@wraps(view)
|
||||
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if request.method in ("GET", "DELETE"):
|
||||
raw = request.args.to_dict(flat=True)
|
||||
else:
|
||||
raw = request.get_json(silent=True) or {}
|
||||
|
||||
try:
|
||||
validated = model.model_validate(raw)
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
return view(self, validated, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
@ -45,6 +45,15 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
.limit(1)
|
||||
)
|
||||
else:
|
||||
# Try id first (preserves the original "explicit end-user
|
||||
# id → that specific user" semantics for callers that pass
|
||||
# a known EndUser.id). Fall back to session_id so daemon-
|
||||
# supplied session UUIDs dedup against the row created on
|
||||
# the first Reverse Invocation call — without this, an
|
||||
# id-only lookup never matched (create writes user_id to
|
||||
# session_id, id is auto-generated) and a fresh EndUser
|
||||
# was created per call, breaking multi-turn chat
|
||||
# continuation (see #36736).
|
||||
user_model = session.scalar(
|
||||
select(EndUser)
|
||||
.where(
|
||||
@ -53,6 +62,15 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if user_model is None:
|
||||
user_model = session.scalar(
|
||||
select(EndUser)
|
||||
.where(
|
||||
EndUser.session_id == user_id,
|
||||
EndUser.tenant_id == tenant_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not user_model:
|
||||
user_model = EndUser(
|
||||
|
||||
@ -7,7 +7,7 @@ from hmac import new as hmac_new
|
||||
from flask import abort, request
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from core.db.session_factory import session_factory
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
@ -44,6 +44,8 @@ def enterprise_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
|
||||
|
||||
def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Inject an EndUser for valid inner API HMAC auth, otherwise pass the request through unchanged."""
|
||||
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not dify_config.INNER_API:
|
||||
@ -72,9 +74,9 @@ def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P,
|
||||
if signature_base64 != token:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
kwargs["user"] = db.session.get(EndUser, user_id)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
with session_factory.create_session() as session:
|
||||
kwargs["user"] = session.get(EndUser, user_id)
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
@ -147,7 +147,7 @@ class AppDescribeApi(AppReadResource):
|
||||
class AppListApi(Resource):
|
||||
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
|
||||
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@auth_router.guard_workspace(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
try:
|
||||
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.openapi.auth.conditions import (
|
||||
EDITION_CE,
|
||||
EDITION_EE,
|
||||
HAS_ALLOWED_ROLES,
|
||||
LOADED_APP_IS_PRIVATE,
|
||||
PATH_HAS_APP_ID,
|
||||
WEBAPP_AUTH_ENABLED,
|
||||
WORKSPACE_MEMBERSHIP_REQUIRED,
|
||||
WORKSPACE_SCOPED,
|
||||
)
|
||||
from controllers.openapi.auth.data import Edition
|
||||
from controllers.openapi.auth.flow import When
|
||||
@ -15,14 +17,18 @@ from controllers.openapi.auth.prepare import (
|
||||
load_app,
|
||||
load_app_access_mode,
|
||||
load_tenant,
|
||||
load_tenant_from_request,
|
||||
load_workspace_role,
|
||||
resolve_external_user,
|
||||
)
|
||||
from controllers.openapi.auth.verify import (
|
||||
check_acl,
|
||||
check_app_access,
|
||||
check_membership,
|
||||
check_app_api_enabled,
|
||||
check_private_app_permission,
|
||||
check_scope,
|
||||
check_workspace_member,
|
||||
check_workspace_mismatch,
|
||||
check_workspace_role,
|
||||
)
|
||||
from libs.oauth_bearer import TokenType
|
||||
|
||||
@ -30,13 +36,17 @@ account_pipeline = AuthPipeline(
|
||||
prepare=[
|
||||
When(PATH_HAS_APP_ID, then=load_app),
|
||||
When(PATH_HAS_APP_ID, then=load_tenant),
|
||||
load_account, # all tokens here are account tokens
|
||||
When(WORKSPACE_MEMBERSHIP_REQUIRED, then=load_tenant_from_request),
|
||||
load_account,
|
||||
When(WORKSPACE_SCOPED, then=load_workspace_role),
|
||||
When(PATH_HAS_APP_ID & EDITION_EE, then=load_app_access_mode),
|
||||
],
|
||||
auth=[
|
||||
When(PATH_HAS_APP_ID, then=check_app_api_enabled),
|
||||
check_scope,
|
||||
When(EDITION_CE & PATH_HAS_APP_ID, then=check_membership),
|
||||
When(EDITION_EE & PATH_HAS_APP_ID & ~WEBAPP_AUTH_ENABLED, then=check_app_access),
|
||||
When(WORKSPACE_SCOPED, then=check_workspace_member),
|
||||
When(PATH_HAS_APP_ID, then=check_workspace_mismatch),
|
||||
When(HAS_ALLOWED_ROLES, then=check_workspace_role),
|
||||
When(PATH_HAS_APP_ID & EDITION_EE & WEBAPP_AUTH_ENABLED, then=check_acl),
|
||||
When(EDITION_EE & LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
|
||||
],
|
||||
@ -50,6 +60,7 @@ external_sso_pipeline = AuthPipeline(
|
||||
When(PATH_HAS_APP_ID, then=load_app_access_mode),
|
||||
],
|
||||
auth=[
|
||||
When(PATH_HAS_APP_ID, then=check_app_api_enabled),
|
||||
check_scope,
|
||||
When(PATH_HAS_APP_ID & WEBAPP_AUTH_ENABLED, then=check_acl),
|
||||
When(LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
|
||||
|
||||
@ -50,4 +50,11 @@ EDITION_SAAS = config_cond(lambda: current_edition() == Edition.SAAS)
|
||||
|
||||
WEBAPP_AUTH_ENABLED = config_cond(lambda: FeatureService.get_system_features().webapp_auth.enabled)
|
||||
|
||||
WORKSPACE_MEMBERSHIP_REQUIRED = request_cond(lambda ctx: ctx.workspace_membership)
|
||||
HAS_ALLOWED_ROLES = request_cond(lambda ctx: ctx.allowed_roles is not None)
|
||||
|
||||
# Caller must belong to the resolved tenant: either an app-scoped path (tenant
|
||||
# from the app) or an explicit workspace-membership path (tenant from request).
|
||||
WORKSPACE_SCOPED = PATH_HAS_APP_ID | WORKSPACE_MEMBERSHIP_REQUIRED
|
||||
|
||||
LOADED_APP_IS_PRIVATE = data_cond(lambda data: data.app_access_mode == WebAppAccessMode.PRIVATE)
|
||||
|
||||
@ -9,7 +9,7 @@ from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from configs import dify_config
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models.account import Account, Tenant
|
||||
from models.account import Account, Tenant, TenantAccountRole
|
||||
from models.model import App, EndUser
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
|
||||
@ -41,6 +41,8 @@ class RequestContext(BaseModel):
|
||||
token_type: TokenType
|
||||
scope: Scope | None = None
|
||||
path_params: dict[str, str]
|
||||
workspace_membership: bool = False
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None
|
||||
|
||||
|
||||
class AuthData(BaseModel):
|
||||
@ -56,10 +58,14 @@ class AuthData(BaseModel):
|
||||
external_identity: ExternalIdentity | None = None
|
||||
path_params: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None
|
||||
|
||||
app: App | None = None
|
||||
tenant: Tenant | None = None
|
||||
app_access_mode: WebAppAccessMode | None = None
|
||||
|
||||
tenant_role: TenantAccountRole | None = None
|
||||
|
||||
caller: Account | EndUser | None = None
|
||||
caller_kind: Literal["account", "end_user"] | None = None
|
||||
|
||||
|
||||
@ -34,6 +34,7 @@ from libs.oauth_bearer import (
|
||||
reset_auth_ctx,
|
||||
set_auth_ctx,
|
||||
)
|
||||
from models.account import TenantAccountRole
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
|
||||
|
||||
@ -56,11 +57,15 @@ class AuthPipeline:
|
||||
view: Callable,
|
||||
*,
|
||||
scope: Scope | None,
|
||||
workspace_membership: bool = False,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
) -> Any:
|
||||
req_ctx = RequestContext(
|
||||
token_type=identity.token_type,
|
||||
scope=scope,
|
||||
path_params=dict(request.view_args or {}),
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
)
|
||||
|
||||
data = AuthData(
|
||||
@ -71,6 +76,7 @@ class AuthPipeline:
|
||||
scopes=frozenset(identity.scopes),
|
||||
tenants=dict(identity.verified_tenants),
|
||||
required_scope=scope,
|
||||
allowed_roles=allowed_roles,
|
||||
path_params=dict(req_ctx.path_params),
|
||||
external_identity=(
|
||||
ExternalIdentity(email=identity.subject_email, issuer=identity.subject_issuer)
|
||||
@ -121,6 +127,41 @@ class PipelineRouter:
|
||||
scope: Scope | None = None,
|
||||
allowed_token_types: frozenset[TokenType] | None = None,
|
||||
edition: frozenset[Edition] | None = None,
|
||||
workspace_membership: bool = False,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
) -> Callable:
|
||||
return self._make_decorator(
|
||||
scope=scope,
|
||||
allowed_token_types=allowed_token_types,
|
||||
edition=edition,
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
)
|
||||
|
||||
def guard_workspace(
|
||||
self,
|
||||
*,
|
||||
scope: Scope | None = None,
|
||||
allowed_token_types: frozenset[TokenType] | None = None,
|
||||
edition: frozenset[Edition] | None = None,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
) -> Callable:
|
||||
return self._make_decorator(
|
||||
scope=scope,
|
||||
allowed_token_types=allowed_token_types,
|
||||
edition=edition,
|
||||
workspace_membership=True,
|
||||
allowed_roles=allowed_roles,
|
||||
)
|
||||
|
||||
def _make_decorator(
|
||||
self,
|
||||
*,
|
||||
scope: Scope | None,
|
||||
allowed_token_types: frozenset[TokenType] | None,
|
||||
edition: frozenset[Edition] | None,
|
||||
workspace_membership: bool,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None,
|
||||
) -> Callable:
|
||||
def decorator(view: Callable) -> Callable:
|
||||
@wraps(view)
|
||||
@ -132,6 +173,8 @@ class PipelineRouter:
|
||||
scope=scope,
|
||||
allowed_token_types=allowed_token_types,
|
||||
edition=edition,
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
)
|
||||
|
||||
return decorated
|
||||
@ -147,6 +190,8 @@ class PipelineRouter:
|
||||
scope: Scope | None,
|
||||
allowed_token_types: frozenset[TokenType] | None,
|
||||
edition: frozenset[Edition] | None,
|
||||
workspace_membership: bool = False,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
) -> Any:
|
||||
# 404 not 403 — this edition doesn't expose the feature at all
|
||||
if edition is not None and current_edition() not in edition:
|
||||
@ -182,7 +227,15 @@ class PipelineRouter:
|
||||
if not license_checked and Edition.EE in route.required_edition:
|
||||
_check_license()
|
||||
|
||||
return route.pipeline._run(identity, args, kwargs, view, scope=scope)
|
||||
return route.pipeline._run(
|
||||
identity,
|
||||
args,
|
||||
kwargs,
|
||||
view,
|
||||
scope=scope,
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
)
|
||||
|
||||
|
||||
def _should_run(step: Any, req_ctx: RequestContext, data: AuthData | None) -> bool:
|
||||
|
||||
@ -1,5 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from flask import request
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
@ -13,16 +16,18 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppAcce
|
||||
|
||||
|
||||
def load_app(data: AuthData) -> None:
|
||||
if data.app is not None:
|
||||
return
|
||||
app_id = data.path_params["app_id"]
|
||||
app = AppService.get_app_by_id(db.session, app_id)
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
if not app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
data.app = app
|
||||
|
||||
|
||||
def load_tenant(data: AuthData) -> None:
|
||||
if data.tenant is not None:
|
||||
return
|
||||
if data.app is None:
|
||||
raise InternalServerError("pipeline_invariant_violated: app not loaded before load_tenant")
|
||||
tenant = TenantService.get_tenant_by_id(db.session, str(data.app.tenant_id))
|
||||
@ -31,7 +36,25 @@ def load_tenant(data: AuthData) -> None:
|
||||
data.tenant = tenant
|
||||
|
||||
|
||||
def load_tenant_from_request(data: AuthData) -> None:
|
||||
if data.tenant is not None:
|
||||
return
|
||||
workspace_id = data.path_params.get("workspace_id") or request.args.get("workspace_id")
|
||||
if not workspace_id:
|
||||
raise NotFound("workspace not found")
|
||||
try:
|
||||
uuid.UUID(workspace_id)
|
||||
except ValueError:
|
||||
raise NotFound("workspace not found")
|
||||
tenant = TenantService.get_tenant_by_id(db.session, workspace_id)
|
||||
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
|
||||
raise NotFound("workspace not found")
|
||||
data.tenant = tenant
|
||||
|
||||
|
||||
def load_account(data: AuthData) -> None:
|
||||
if data.caller is not None:
|
||||
return
|
||||
account = AccountService.get_account_by_id(db.session, str(data.account_id))
|
||||
if account is None:
|
||||
raise Unauthorized("account not found")
|
||||
@ -41,6 +64,19 @@ def load_account(data: AuthData) -> None:
|
||||
data.caller_kind = "account"
|
||||
|
||||
|
||||
def load_workspace_role(data: AuthData) -> None:
|
||||
if data.tenant_role is not None:
|
||||
return
|
||||
if data.tenant is None or data.account_id is None:
|
||||
return
|
||||
if data.caller is not None and getattr(data.caller, "status", None) != "active":
|
||||
return
|
||||
role = TenantService.get_account_role_in_tenant(db.session, str(data.account_id), str(data.tenant.id))
|
||||
if role is None:
|
||||
return
|
||||
data.tenant_role = role
|
||||
|
||||
|
||||
def resolve_external_user(data: AuthData) -> None:
|
||||
if data.tenant is None or data.app is None or data.external_identity is None:
|
||||
raise Unauthorized("missing context for external user resolution")
|
||||
|
||||
@ -1,77 +0,0 @@
|
||||
"""Workspace role gate.
|
||||
|
||||
Layered on top of `validate_bearer` + `accept_subjects(SubjectType.ACCOUNT)`
|
||||
for routes whose access depends on the caller's `TenantAccountJoin.role`
|
||||
in the workspace named by the `workspace_id` path parameter.
|
||||
|
||||
Usage::
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/members")
|
||||
class Members(Resource):
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
@require_workspace_role() # any member
|
||||
def get(self, workspace_id: str): ...
|
||||
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
def post(self, workspace_id: str): ...
|
||||
|
||||
Non-member callers get 404 (matching `GET /openapi/v1/workspaces/<id>`)
|
||||
so workspace IDs do not leak across tenants. A member without one of the
|
||||
allowed roles gets 403.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TypeVar
|
||||
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import try_get_auth_ctx
|
||||
from models.account import TenantAccountRole
|
||||
from services.account_service import TenantService
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., object])
|
||||
|
||||
|
||||
def require_workspace_role(*allowed_roles: TenantAccountRole) -> Callable[[F], F]:
|
||||
"""Gate a route on the caller's role in ``workspace_id``.
|
||||
|
||||
Pass no roles to require only membership. Pass one or more roles to
|
||||
require the caller's role be in that set.
|
||||
"""
|
||||
|
||||
allowed = frozenset(allowed_roles)
|
||||
|
||||
def deco(fn: F) -> F:
|
||||
@wraps(fn)
|
||||
def wrapper(*args: object, **kwargs: object) -> object:
|
||||
ctx = try_get_auth_ctx()
|
||||
if ctx is None or ctx.account_id is None:
|
||||
raise RuntimeError(
|
||||
"require_workspace_role called without account-bearer context; "
|
||||
"stack validate_bearer + accept_subjects(SubjectType.ACCOUNT) above it"
|
||||
)
|
||||
|
||||
workspace_id = kwargs.get("workspace_id")
|
||||
if not workspace_id:
|
||||
raise RuntimeError("require_workspace_role expects a 'workspace_id' route parameter")
|
||||
|
||||
role = TenantService.get_account_role_in_tenant(db.session, str(ctx.account_id), str(workspace_id))
|
||||
|
||||
if role is None:
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
if allowed and role not in allowed:
|
||||
raise Forbidden("insufficient workspace role")
|
||||
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
|
||||
return deco
|
||||
@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from werkzeug.exceptions import Forbidden, Unauthorized
|
||||
from flask import request
|
||||
from werkzeug.exceptions import Forbidden, NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType, check_workspace_membership
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode
|
||||
|
||||
@ -17,17 +18,39 @@ def check_scope(data: AuthData) -> None:
|
||||
raise Forbidden("insufficient_scope")
|
||||
|
||||
|
||||
def check_membership(data: AuthData) -> None:
|
||||
def check_workspace_member(data: AuthData) -> None:
|
||||
"""Assert the caller belongs to the resolved tenant.
|
||||
|
||||
`load_workspace_role` stashes the membership role (None when the caller is
|
||||
not a member or is inactive). A missing membership surfaces as 404, not
|
||||
403, so workspace IDs don't leak across tenants.
|
||||
"""
|
||||
if data.tenant_role is None:
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
|
||||
def check_workspace_mismatch(data: AuthData) -> None:
|
||||
if data.tenant is None:
|
||||
raise Unauthorized("tenant unset")
|
||||
if data.account_id is None:
|
||||
raise Unauthorized("account_id unset")
|
||||
check_workspace_membership(
|
||||
account_id=data.account_id,
|
||||
tenant_id=data.tenant.id,
|
||||
token_hash=data.token_hash,
|
||||
membership_cache=data.tenants,
|
||||
)
|
||||
return
|
||||
request_workspace_id = data.path_params.get("workspace_id") or request.args.get("workspace_id")
|
||||
if request_workspace_id and request_workspace_id != str(data.tenant.id):
|
||||
raise UnprocessableEntity("workspace_id does not match app's workspace")
|
||||
|
||||
|
||||
def check_workspace_role(data: AuthData) -> None:
|
||||
if data.allowed_roles is None:
|
||||
return
|
||||
if data.tenant_role is None:
|
||||
raise NotFound("workspace not found")
|
||||
if data.tenant_role not in data.allowed_roles:
|
||||
raise Forbidden("insufficient workspace role")
|
||||
|
||||
|
||||
def check_app_api_enabled(data: AuthData) -> None:
|
||||
if data.app is None:
|
||||
return
|
||||
if not data.app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
|
||||
|
||||
def check_app_access(data: AuthData) -> None:
|
||||
|
||||
@ -26,7 +26,12 @@ from werkzeug.exceptions import BadRequest
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
AccountPayload,
|
||||
@ -42,7 +47,6 @@ from controllers.openapi._models import (
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType, bearer_feature_required
|
||||
from libs.rate_limit import (
|
||||
LIMIT_APPROVE_CONSOLE,
|
||||
@ -50,6 +54,7 @@ from libs.rate_limit import (
|
||||
LIMIT_LOOKUP_PUBLIC,
|
||||
rate_limit,
|
||||
)
|
||||
from models import Account
|
||||
from services.account_service import TenantService
|
||||
from services.oauth_device_flow import (
|
||||
ACCOUNT_ISSUER_SENTINEL,
|
||||
@ -206,11 +211,12 @@ class DeviceApproveApi(Resource):
|
||||
@account_initialization_required
|
||||
@bearer_feature_required
|
||||
@rate_limit(LIMIT_APPROVE_CONSOLE)
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant: str, account: Account):
|
||||
payload = _validate_json(DeviceMutateRequest)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
account, tenant = current_account_with_tenant()
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
|
||||
found = store.load_by_user_code(user_code)
|
||||
|
||||
@ -5,9 +5,8 @@ endpoints. Account bearers (dfoa_) see every tenant they're a member of.
|
||||
External SSO bearers (dfoe_) have no account_id and so see an empty list —
|
||||
that matches /openapi/v1/account.
|
||||
|
||||
Member-management endpoints are gated by both `accept_subjects` (SSO out)
|
||||
and `require_workspace_role` (membership / role lookup against the path's
|
||||
``workspace_id``).
|
||||
Member-management endpoints use ``guard_workspace`` which enforces
|
||||
workspace membership and optional role requirements via the auth pipeline.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -37,7 +36,6 @@ from controllers.openapi._models import (
|
||||
)
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.auth.role_gate import require_workspace_role
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models import Account, Tenant, TenantAccountJoin
|
||||
@ -152,8 +150,7 @@ class WorkspaceSwitchApi(Resource):
|
||||
"""
|
||||
|
||||
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role()
|
||||
@auth_router.guard_workspace(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def post(self, workspace_id: str, *, auth_data: AuthData):
|
||||
account = _load_account(auth_data.account_id)
|
||||
|
||||
@ -179,8 +176,7 @@ class WorkspaceMembersApi(Resource):
|
||||
|
||||
@openapi_ns.doc(params=query_params_from_model(MemberListQuery))
|
||||
@openapi_ns.response(200, "Member list", openapi_ns.models[MemberListResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role()
|
||||
@auth_router.guard_workspace(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, workspace_id: str, *, auth_data: AuthData):
|
||||
try:
|
||||
query = MemberListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -202,8 +198,11 @@ class WorkspaceMembersApi(Resource):
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[MemberInvitePayload.__name__])
|
||||
@openapi_ns.response(201, "Member invited", openapi_ns.models[MemberInviteResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
@auth_router.guard_workspace(
|
||||
scope=Scope.WORKSPACE_WRITE,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
|
||||
)
|
||||
def post(self, workspace_id: str, *, auth_data: AuthData):
|
||||
payload = _validate_body(MemberInvitePayload)
|
||||
inviter = _load_account(auth_data.account_id)
|
||||
@ -253,8 +252,11 @@ class WorkspaceMemberApi(Resource):
|
||||
"""
|
||||
|
||||
@openapi_ns.response(200, "Member removed", openapi_ns.models[MemberActionResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
@auth_router.guard_workspace(
|
||||
scope=Scope.WORKSPACE_WRITE,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
|
||||
)
|
||||
def delete(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
|
||||
operator = _load_account(auth_data.account_id)
|
||||
tenant = _load_tenant(workspace_id)
|
||||
@ -284,8 +286,11 @@ class WorkspaceMemberRoleApi(Resource):
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[MemberRoleUpdatePayload.__name__])
|
||||
@openapi_ns.response(200, "Role updated", openapi_ns.models[MemberActionResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
@auth_router.guard_workspace(
|
||||
scope=Scope.WORKSPACE_WRITE,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
|
||||
)
|
||||
def put(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
|
||||
payload = _validate_body(MemberRoleUpdatePayload)
|
||||
operator = _load_account(auth_data.account_id)
|
||||
|
||||
@ -6,7 +6,7 @@ from flask_restx import Resource
|
||||
from flask_restx.api import HTTPStatus
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import query_params_from_model, register_schema_models
|
||||
from controllers.console.wraps import edit_permission_required
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
@ -32,8 +32,19 @@ class AnnotationReplyActionPayload(BaseModel):
|
||||
embedding_model_name: str = Field(description="Embedding model name")
|
||||
|
||||
|
||||
class AnnotationListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, description="Page number")
|
||||
limit: int = Field(default=20, ge=1, description="Number of annotations per page")
|
||||
keyword: str = Field(default="", description="Keyword to search annotations")
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload, Annotation, AnnotationList
|
||||
service_api_ns,
|
||||
AnnotationCreatePayload,
|
||||
AnnotationReplyActionPayload,
|
||||
AnnotationListQuery,
|
||||
Annotation,
|
||||
AnnotationList,
|
||||
)
|
||||
|
||||
|
||||
@ -100,6 +111,7 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
class AnnotationListApi(Resource):
|
||||
@service_api_ns.doc("list_annotations")
|
||||
@service_api_ns.doc(description="List annotations for the application")
|
||||
@service_api_ns.doc(params=query_params_from_model(AnnotationListQuery))
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Annotations retrieved successfully",
|
||||
@ -114,18 +126,18 @@ class AnnotationListApi(Resource):
|
||||
@validate_app_token
|
||||
def get(self, app_model: App):
|
||||
"""List annotations for the application."""
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
keyword = request.args.get("keyword", default="", type=str)
|
||||
query = AnnotationListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword)
|
||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
|
||||
app_model.id, query.page, query.limit, query.keyword
|
||||
)
|
||||
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
|
||||
response = AnnotationList(
|
||||
data=annotation_models,
|
||||
has_more=len(annotation_list) == limit,
|
||||
limit=limit,
|
||||
has_more=len(annotation_list) == query.limit,
|
||||
limit=query.limit,
|
||||
total=total,
|
||||
page=page,
|
||||
page=query.page,
|
||||
)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
|
||||
@ -41,6 +41,15 @@ from services.errors.llm import InvokeRateLimitError
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_agent_app_streaming(*, app_mode: AppMode, response_mode: str | None) -> bool:
|
||||
"""Agent App runtime is SSE-only until backend blocking runs are supported."""
|
||||
if app_mode != AppMode.AGENT:
|
||||
return response_mode == "streaming"
|
||||
if response_mode == "blocking":
|
||||
raise BadRequest("Agent App only supports streaming response mode.")
|
||||
return True
|
||||
|
||||
|
||||
class CompletionRequestPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str = Field(default="")
|
||||
@ -197,7 +206,7 @@ class ChatApi(Resource):
|
||||
Supports conversation management and both blocking and streaming response modes.
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
payload = ChatRequestPayload.model_validate(service_api_ns.payload or {})
|
||||
@ -207,7 +216,7 @@ class ChatApi(Resource):
|
||||
if external_trace_id:
|
||||
args["external_trace_id"] = external_trace_id
|
||||
|
||||
streaming = payload.response_mode == "streaming"
|
||||
streaming = _resolve_agent_app_streaming(app_mode=app_mode, response_mode=payload.response_mode)
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
@ -262,7 +271,7 @@ class ChatStopApi(Resource):
|
||||
def post(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
"""Stop a running chat message generation."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
AppTaskService.stop_task(
|
||||
|
||||
@ -155,7 +155,7 @@ class ConversationApi(Resource):
|
||||
Supports pagination using last_id and limit parameters.
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
query_args = ConversationListQuery.model_validate(request.args.to_dict())
|
||||
@ -199,7 +199,7 @@ class ConversationDetailApi(Resource):
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
"""Delete a specific conversation."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@ -228,7 +228,7 @@ class ConversationRenameApi(Resource):
|
||||
def post(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
"""Rename a conversation or auto-generate a name."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
@ -56,7 +56,7 @@ class MessageListApi(Resource):
|
||||
Retrieves messages with pagination support using first_id.
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
query_args = MessageListQuery.model_validate(request.args.to_dict())
|
||||
@ -167,7 +167,7 @@ class MessageSuggestedApi(Resource):
|
||||
"""
|
||||
message_id_str = str(message_id)
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
try:
|
||||
|
||||
@ -12,7 +12,6 @@ from typing import Self
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request, send_file
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from sqlalchemy import desc, func, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
@ -27,7 +26,12 @@ from controllers.common.errors import (
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.common.fields import UrlResponse
|
||||
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
|
||||
from controllers.common.schema import (
|
||||
query_params_from_model,
|
||||
register_enum_models,
|
||||
register_response_schema_models,
|
||||
register_schema_models,
|
||||
)
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.dataset.error import (
|
||||
@ -44,7 +48,13 @@ from core.errors.error import ProviderTokenNotInitError
|
||||
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from fields.document_fields import document_fields, document_status_fields
|
||||
from fields.base import ResponseModel
|
||||
from fields.document_fields import (
|
||||
DocumentListResponse,
|
||||
DocumentResponse,
|
||||
DocumentStatusListResponse,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
from libs.login import current_user
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import SegmentStatus
|
||||
@ -107,6 +117,44 @@ class DocumentListQuery(BaseModel):
|
||||
status: str | None = Field(default=None, description="Document status filter")
|
||||
|
||||
|
||||
DOCUMENT_CREATE_BY_FILE_PARAMS = {
|
||||
"dataset_id": "Dataset ID",
|
||||
"file": {
|
||||
"in": "formData",
|
||||
"type": "file",
|
||||
"required": True,
|
||||
"description": "Document file to upload.",
|
||||
},
|
||||
"data": {
|
||||
"in": "formData",
|
||||
"type": "string",
|
||||
"required": False,
|
||||
"description": "Optional JSON string with document creation settings.",
|
||||
},
|
||||
}
|
||||
DOCUMENT_UPDATE_BY_FILE_PARAMS = {
|
||||
"dataset_id": "Dataset ID",
|
||||
"document_id": "Document ID",
|
||||
"file": {
|
||||
"in": "formData",
|
||||
"type": "file",
|
||||
"required": False,
|
||||
"description": "Replacement document file.",
|
||||
},
|
||||
"data": {
|
||||
"in": "formData",
|
||||
"type": "string",
|
||||
"required": False,
|
||||
"description": "Optional JSON string with document update settings.",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class DocumentAndBatchResponse(ResponseModel):
|
||||
document: DocumentResponse
|
||||
batch: str
|
||||
|
||||
|
||||
register_enum_models(service_api_ns, RetrievalMethod)
|
||||
|
||||
register_schema_models(
|
||||
@ -121,7 +169,14 @@ register_schema_models(
|
||||
PreProcessingRule,
|
||||
Segmentation,
|
||||
)
|
||||
register_response_schema_models(service_api_ns, UrlResponse)
|
||||
register_response_schema_models(
|
||||
service_api_ns,
|
||||
UrlResponse,
|
||||
DocumentResponse,
|
||||
DocumentAndBatchResponse,
|
||||
DocumentListResponse,
|
||||
DocumentStatusListResponse,
|
||||
)
|
||||
|
||||
|
||||
def _create_document_by_text(tenant_id: str, dataset_id: UUID) -> tuple[Mapping[str, object], int]:
|
||||
@ -188,8 +243,7 @@ def _create_document_by_text(tenant_id: str, dataset_id: UUID) -> tuple[Mapping[
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||
return documents_and_batch_fields, 200
|
||||
return dump_response(DocumentAndBatchResponse, {"document": document, "batch": batch}), 200
|
||||
|
||||
|
||||
def _update_document_by_text(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
|
||||
@ -248,8 +302,7 @@ def _update_document_by_text(tenant_id: str, dataset_id: UUID, document_id: UUID
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||
return documents_and_batch_fields, 200
|
||||
return dump_response(DocumentAndBatchResponse, {"document": document, "batch": batch}), 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/document/create-by-text")
|
||||
@ -267,6 +320,9 @@ class DocumentAddByTextApi(DatasetApiResource):
|
||||
400: "Bad request - invalid parameters",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200, "Document created successfully", service_api_ns.models[DocumentAndBatchResponse.__name__]
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_resource_check("documents", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
@ -296,6 +352,9 @@ class DeprecatedDocumentAddByTextApi(DatasetApiResource):
|
||||
400: "Bad request - invalid parameters",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200, "Document created successfully", service_api_ns.models[DocumentAndBatchResponse.__name__]
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_resource_check("documents", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
@ -319,6 +378,9 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
404: "Document not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200, "Document updated successfully", service_api_ns.models[DocumentAndBatchResponse.__name__]
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
@ -347,6 +409,9 @@ class DeprecatedDocumentUpdateByTextApi(DatasetApiResource):
|
||||
404: "Document not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200, "Document updated successfully", service_api_ns.models[DocumentAndBatchResponse.__name__]
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
@ -363,7 +428,7 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
|
||||
@service_api_ns.doc("create_document_by_file")
|
||||
@service_api_ns.doc(description="Create a new document by uploading a file")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@service_api_ns.doc(consumes=["multipart/form-data"], params=DOCUMENT_CREATE_BY_FILE_PARAMS)
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Document created successfully",
|
||||
@ -371,6 +436,9 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
400: "Bad request - invalid file or parameters",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200, "Document created successfully", service_api_ns.models[DocumentAndBatchResponse.__name__]
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_resource_check("documents", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
@ -462,8 +530,7 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||
return documents_and_batch_fields, 200
|
||||
return dump_response(DocumentAndBatchResponse, {"document": document, "batch": batch}), 200
|
||||
|
||||
|
||||
def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
|
||||
@ -539,8 +606,7 @@ def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
|
||||
return documents_and_batch_fields, 200
|
||||
return dump_response(DocumentAndBatchResponse, {"document": document, "batch": document.batch}), 200
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
@ -558,7 +624,7 @@ class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
|
||||
"Use PATCH /datasets/{dataset_id}/documents/{document_id} instead."
|
||||
)
|
||||
)
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@service_api_ns.doc(consumes=["multipart/form-data"], params=DOCUMENT_UPDATE_BY_FILE_PARAMS)
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Document updated successfully",
|
||||
@ -566,6 +632,9 @@ class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
|
||||
404: "Document not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200, "Document updated successfully", service_api_ns.models[DocumentAndBatchResponse.__name__]
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
@ -577,7 +646,7 @@ class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
|
||||
class DocumentListApi(DatasetApiResource):
|
||||
@service_api_ns.doc("list_documents")
|
||||
@service_api_ns.doc(description="List all documents in a dataset")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", **query_params_from_model(DocumentListQuery)})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Documents retrieved successfully",
|
||||
@ -585,6 +654,9 @@ class DocumentListApi(DatasetApiResource):
|
||||
404: "Dataset not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200, "Documents retrieved successfully", service_api_ns.models[DocumentListResponse.__name__]
|
||||
)
|
||||
def get(self, tenant_id, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
@ -618,14 +690,14 @@ class DocumentListApi(DatasetApiResource):
|
||||
)
|
||||
|
||||
response = {
|
||||
"data": marshal(documents, document_fields),
|
||||
"data": documents,
|
||||
"has_more": len(documents) == query_params.limit,
|
||||
"limit": query_params.limit,
|
||||
"total": paginated_documents.total,
|
||||
"page": query_params.page,
|
||||
}
|
||||
|
||||
return response
|
||||
return dump_response(DocumentListResponse, response)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/download-zip")
|
||||
@ -680,6 +752,11 @@ class DocumentIndexingStatusApi(DatasetApiResource):
|
||||
404: "Dataset or documents not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Indexing status retrieved successfully",
|
||||
service_api_ns.models[DocumentStatusListResponse.__name__],
|
||||
)
|
||||
def get(self, tenant_id, dataset_id: UUID, batch: str):
|
||||
dataset_id_str = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
@ -729,9 +806,8 @@ class DocumentIndexingStatusApi(DatasetApiResource):
|
||||
"completed_segments": completed_segments,
|
||||
"total_segments": total_segments,
|
||||
}
|
||||
documents_status.append(marshal(document_dict, document_status_fields))
|
||||
data = {"data": documents_status}
|
||||
return data
|
||||
documents_status.append(document_dict)
|
||||
return dump_response(DocumentStatusListResponse, {"data": documents_status})
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/download")
|
||||
@ -890,7 +966,7 @@ class DocumentApi(DatasetApiResource):
|
||||
|
||||
@service_api_ns.doc("update_document_by_file")
|
||||
@service_api_ns.doc(description="Update an existing document by uploading a file")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@service_api_ns.doc(consumes=["multipart/form-data"], params=DOCUMENT_UPDATE_BY_FILE_PARAMS)
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Document updated successfully",
|
||||
@ -898,6 +974,9 @@ class DocumentApi(DatasetApiResource):
|
||||
404: "Document not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200, "Document updated successfully", service_api_ns.models[DocumentAndBatchResponse.__name__]
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
|
||||
@ -1,15 +1,18 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import (
|
||||
query_params_from_model,
|
||||
query_params_from_request,
|
||||
register_response_schema_models,
|
||||
register_schema_models,
|
||||
)
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.wraps import (
|
||||
@ -22,10 +25,19 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from extensions.ext_database import db
|
||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||
from fields.base import ResponseModel
|
||||
from fields.segment_fields import (
|
||||
ChildChunkDetailResponse,
|
||||
ChildChunkListResponse,
|
||||
SegmentDetailResponse,
|
||||
SegmentResponse,
|
||||
segment_response_with_summary,
|
||||
segment_responses_with_summaries,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.helper import dump_response
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.dataset import Dataset
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
|
||||
@ -34,35 +46,27 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
|
||||
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict[str, Any]:
|
||||
"""Marshal a single segment and enrich it with summary content."""
|
||||
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
|
||||
segment_dict["summary"] = summary.summary_content if summary else None
|
||||
return segment_dict
|
||||
class SegmentCreateItemPayload(BaseModel):
|
||||
content: str = Field(min_length=1)
|
||||
answer: str | None = None
|
||||
keywords: list[str] | None = None
|
||||
attachment_ids: list[str] | None = None
|
||||
|
||||
|
||||
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict[str, Any]]:
|
||||
"""Marshal multiple segments and enrich them with summary content (batch query)."""
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
summaries: dict[str, str | None] = {}
|
||||
if segment_ids:
|
||||
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
|
||||
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
|
||||
|
||||
result: list[dict[str, Any]] = []
|
||||
for segment in segments:
|
||||
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
segment_dict["summary"] = summaries.get(segment.id)
|
||||
result.append(segment_dict)
|
||||
return result
|
||||
@field_validator("content")
|
||||
@classmethod
|
||||
def validate_content(cls, value: str) -> str:
|
||||
if not value.strip():
|
||||
raise ValueError("Content is empty")
|
||||
return value
|
||||
|
||||
|
||||
class SegmentCreatePayload(BaseModel):
|
||||
segments: list[dict[str, Any]] | None = None
|
||||
segments: list[SegmentCreateItemPayload] = Field(min_length=1)
|
||||
|
||||
|
||||
class SegmentListQuery(BaseModel):
|
||||
limit: int = Field(default=20, ge=1)
|
||||
page: int = Field(default=1, ge=1)
|
||||
status: list[str] = Field(default_factory=list)
|
||||
keyword: str | None = None
|
||||
|
||||
@ -77,9 +81,31 @@ class ChildChunkListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1)
|
||||
|
||||
|
||||
class SegmentDocParams:
|
||||
DATASET_DOCUMENT = {"dataset_id": "Dataset ID", "document_id": "Document ID"}
|
||||
DATASET_DOCUMENT_SEGMENT = {**DATASET_DOCUMENT, "segment_id": "Segment ID"}
|
||||
DATASET_DOCUMENT_PARENT_SEGMENT = {**DATASET_DOCUMENT, "segment_id": "Parent segment ID"}
|
||||
DATASET_DOCUMENT_CHILD_CHUNK = {**DATASET_DOCUMENT_PARENT_SEGMENT, "child_chunk_id": "Child chunk ID"}
|
||||
|
||||
|
||||
class SegmentCreateListResponse(ResponseModel):
|
||||
data: list[SegmentResponse]
|
||||
doc_form: str
|
||||
|
||||
|
||||
class SegmentListResponse(ResponseModel):
|
||||
data: list[SegmentResponse]
|
||||
doc_form: str
|
||||
total: int
|
||||
has_more: bool
|
||||
limit: int
|
||||
page: int
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
SegmentCreatePayload,
|
||||
SegmentCreateItemPayload,
|
||||
SegmentListQuery,
|
||||
SegmentUpdateArgs,
|
||||
SegmentUpdatePayload,
|
||||
@ -87,6 +113,15 @@ register_schema_models(
|
||||
ChildChunkListQuery,
|
||||
ChildChunkUpdatePayload,
|
||||
)
|
||||
register_response_schema_models(
|
||||
service_api_ns,
|
||||
SegmentResponse,
|
||||
SegmentCreateListResponse,
|
||||
SegmentListResponse,
|
||||
SegmentDetailResponse,
|
||||
ChildChunkDetailResponse,
|
||||
ChildChunkListResponse,
|
||||
)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
||||
@ -96,7 +131,7 @@ class SegmentApi(DatasetApiResource):
|
||||
@service_api_ns.expect(service_api_ns.models[SegmentCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_segments")
|
||||
@service_api_ns.doc(description="Create segments in a document")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Segments created successfully",
|
||||
@ -105,6 +140,11 @@ class SegmentApi(DatasetApiResource):
|
||||
404: "Dataset or document not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Segments created successfully",
|
||||
service_api_ns.models[SegmentCreateListResponse.__name__],
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
@ -144,26 +184,35 @@ class SegmentApi(DatasetApiResource):
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# validate args
|
||||
payload = SegmentCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
if payload.segments is not None:
|
||||
segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
|
||||
if segments_limit > 0 and len(payload.segments) > segments_limit:
|
||||
raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")
|
||||
try:
|
||||
payload = SegmentCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
except ValidationError as e:
|
||||
return {"error": str(e)}, 400
|
||||
segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
|
||||
if segments_limit > 0 and len(payload.segments) > segments_limit:
|
||||
raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")
|
||||
segment_items = [segment.model_dump(exclude_none=True) for segment in payload.segments]
|
||||
|
||||
for args_item in payload.segments:
|
||||
SegmentService.segment_create_args_validate(args_item, document)
|
||||
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
|
||||
return {
|
||||
"data": _marshal_segments_with_summary(segments, dataset_id_str),
|
||||
"doc_form": document.doc_form,
|
||||
}, 200
|
||||
else:
|
||||
return {"error": "Segments is required"}, 400
|
||||
for args_item in segment_items:
|
||||
SegmentService.segment_create_args_validate(args_item, document)
|
||||
segments = cast(list[DocumentSegment], SegmentService.multi_create_segment(segment_items, document, dataset))
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
summaries: dict[str, str | None] = {}
|
||||
if segment_ids:
|
||||
summary_records = SummaryIndexService.get_segments_summaries(
|
||||
segment_ids=segment_ids, dataset_id=dataset_id_str
|
||||
)
|
||||
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
|
||||
response = {
|
||||
"data": segment_responses_with_summaries(segments, summaries),
|
||||
"doc_form": document.doc_form,
|
||||
}
|
||||
return dump_response(SegmentCreateListResponse, response), 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[SegmentListQuery.__name__])
|
||||
@service_api_ns.doc("list_segments")
|
||||
@service_api_ns.doc(description="List segments in a document")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
|
||||
@service_api_ns.doc(params=query_params_from_model(SegmentListQuery))
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Segments retrieved successfully",
|
||||
@ -171,12 +220,22 @@ class SegmentApi(DatasetApiResource):
|
||||
404: "Dataset or document not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Segments retrieved successfully",
|
||||
service_api_ns.models[SegmentListResponse.__name__],
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Get segments."""
|
||||
# check dataset
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
args = query_params_from_request(
|
||||
SegmentListQuery,
|
||||
list_fields=("status",),
|
||||
use_defaults_for_malformed_ints=True,
|
||||
)
|
||||
page = args.page
|
||||
limit = args.limit
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
|
||||
@ -205,13 +264,6 @@ class SegmentApi(DatasetApiResource):
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
args = SegmentListQuery.model_validate(
|
||||
{
|
||||
"status": request.args.getlist("status"),
|
||||
"keyword": request.args.get("keyword"),
|
||||
}
|
||||
)
|
||||
|
||||
segments, total = SegmentService.get_segments(
|
||||
document_id=document_id_str,
|
||||
tenant_id=current_tenant_id,
|
||||
@ -220,9 +272,16 @@ class SegmentApi(DatasetApiResource):
|
||||
page=page,
|
||||
limit=limit,
|
||||
)
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
summaries: dict[str, str | None] = {}
|
||||
if segment_ids:
|
||||
summary_records = SummaryIndexService.get_segments_summaries(
|
||||
segment_ids=segment_ids, dataset_id=dataset_id_str
|
||||
)
|
||||
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
|
||||
|
||||
response = {
|
||||
"data": _marshal_segments_with_summary(segments, dataset_id_str),
|
||||
"data": segment_responses_with_summaries(segments, summaries),
|
||||
"doc_form": document.doc_form,
|
||||
"total": total,
|
||||
"has_more": len(segments) == limit,
|
||||
@ -230,16 +289,14 @@ class SegmentApi(DatasetApiResource):
|
||||
"page": page,
|
||||
}
|
||||
|
||||
return response, 200
|
||||
return dump_response(SegmentListResponse, response), 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
|
||||
class DatasetSegmentApi(DatasetApiResource):
|
||||
@service_api_ns.doc("delete_segment")
|
||||
@service_api_ns.doc(description="Delete a specific segment")
|
||||
@service_api_ns.doc(
|
||||
params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Segment ID to delete"}
|
||||
)
|
||||
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
204: "Segment deleted successfully",
|
||||
@ -275,9 +332,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
@service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_segment")
|
||||
@service_api_ns.doc(description="Update a specific segment")
|
||||
@service_api_ns.doc(
|
||||
params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Segment ID to update"}
|
||||
)
|
||||
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Segment updated successfully",
|
||||
@ -285,6 +340,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
404: "Dataset, document, or segment not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(200, "Segment updated successfully", service_api_ns.models[SegmentDetailResponse.__name__])
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
@ -328,13 +384,16 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
|
||||
return {
|
||||
"data": _marshal_segment_with_summary(updated_segment, dataset_id_str),
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=updated_segment.id, dataset_id=dataset_id_str)
|
||||
response = {
|
||||
"data": segment_response_with_summary(updated_segment, summary.summary_content if summary else None),
|
||||
"doc_form": document.doc_form,
|
||||
}, 200
|
||||
}
|
||||
return dump_response(SegmentDetailResponse, response), 200
|
||||
|
||||
@service_api_ns.doc("get_segment")
|
||||
@service_api_ns.doc(description="Get a specific segment by ID")
|
||||
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Segment retrieved successfully",
|
||||
@ -342,6 +401,11 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
404: "Dataset, document, or segment not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Segment retrieved successfully",
|
||||
service_api_ns.models[SegmentDetailResponse.__name__],
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -364,7 +428,12 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
return {"data": _marshal_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id_str)
|
||||
response = {
|
||||
"data": segment_response_with_summary(segment, summary.summary_content if summary else None),
|
||||
"doc_form": document.doc_form,
|
||||
}
|
||||
return dump_response(SegmentDetailResponse, response), 200
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
@ -376,9 +445,7 @@ class ChildChunkApi(DatasetApiResource):
|
||||
@service_api_ns.expect(service_api_ns.models[ChildChunkCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_child_chunk")
|
||||
@service_api_ns.doc(description="Create a new child chunk for a segment")
|
||||
@service_api_ns.doc(
|
||||
params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Parent segment ID"}
|
||||
)
|
||||
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Child chunk created successfully",
|
||||
@ -386,6 +453,11 @@ class ChildChunkApi(DatasetApiResource):
|
||||
404: "Dataset, document, or segment not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Child chunk created successfully",
|
||||
service_api_ns.models[ChildChunkDetailResponse.__name__],
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
@ -437,14 +509,12 @@ class ChildChunkApi(DatasetApiResource):
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
|
||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
return dump_response(ChildChunkDetailResponse, {"data": child_chunk}), 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[ChildChunkListQuery.__name__])
|
||||
@service_api_ns.doc("list_child_chunks")
|
||||
@service_api_ns.doc(description="List child chunks for a segment")
|
||||
@service_api_ns.doc(
|
||||
params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Parent segment ID"}
|
||||
)
|
||||
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
|
||||
@service_api_ns.doc(params=query_params_from_model(ChildChunkListQuery))
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Child chunks retrieved successfully",
|
||||
@ -452,6 +522,11 @@ class ChildChunkApi(DatasetApiResource):
|
||||
404: "Dataset, document, or segment not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Child chunks retrieved successfully",
|
||||
service_api_ns.models[ChildChunkListResponse.__name__],
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Get child chunks."""
|
||||
@ -475,13 +550,7 @@ class ChildChunkApi(DatasetApiResource):
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
args = ChildChunkListQuery.model_validate(
|
||||
{
|
||||
"limit": request.args.get("limit", default=20, type=int),
|
||||
"keyword": request.args.get("keyword"),
|
||||
"page": request.args.get("page", default=1, type=int),
|
||||
}
|
||||
)
|
||||
args = query_params_from_request(ChildChunkListQuery, use_defaults_for_malformed_ints=True)
|
||||
|
||||
page = args.page
|
||||
limit = min(args.limit, 100)
|
||||
@ -491,13 +560,14 @@ class ChildChunkApi(DatasetApiResource):
|
||||
segment_id_str, document_id_str, dataset_id_str, page, limit, keyword
|
||||
)
|
||||
|
||||
return {
|
||||
"data": marshal(child_chunks.items, child_chunk_fields),
|
||||
response = {
|
||||
"data": child_chunks.items,
|
||||
"total": child_chunks.total,
|
||||
"total_pages": child_chunks.pages,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
}, 200
|
||||
}
|
||||
return dump_response(ChildChunkListResponse, response), 200
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
@ -508,14 +578,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
|
||||
@service_api_ns.doc("delete_child_chunk")
|
||||
@service_api_ns.doc(description="Delete a specific child chunk")
|
||||
@service_api_ns.doc(
|
||||
params={
|
||||
"dataset_id": "Dataset ID",
|
||||
"document_id": "Document ID",
|
||||
"segment_id": "Parent segment ID",
|
||||
"child_chunk_id": "Child chunk ID to delete",
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
204: "Child chunk deleted successfully",
|
||||
@ -549,7 +612,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
# validate segment belongs to the specified document
|
||||
if str(segment.document_id) != str(document_id_str):
|
||||
if segment.document_id != document_id_str:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
child_chunk_id_str = str(child_chunk_id)
|
||||
@ -561,7 +624,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
# validate child chunk belongs to the specified segment
|
||||
if str(child_chunk.segment_id) != str(segment.id):
|
||||
if child_chunk.segment_id != segment.id:
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
try:
|
||||
@ -574,14 +637,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
@service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_child_chunk")
|
||||
@service_api_ns.doc(description="Update a specific child chunk")
|
||||
@service_api_ns.doc(
|
||||
params={
|
||||
"dataset_id": "Dataset ID",
|
||||
"document_id": "Document ID",
|
||||
"segment_id": "Parent segment ID",
|
||||
"child_chunk_id": "Child chunk ID to update",
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Child chunk updated successfully",
|
||||
@ -589,6 +645,11 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
404: "Dataset, document, segment, or child chunk not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Child chunk updated successfully",
|
||||
service_api_ns.models[ChildChunkDetailResponse.__name__],
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
@ -616,7 +677,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
# validate segment belongs to the specified document
|
||||
if str(segment.document_id) != str(document_id_str):
|
||||
if segment.document_id != document_id_str:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
child_chunk_id_str = str(child_chunk_id)
|
||||
@ -628,7 +689,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
# validate child chunk belongs to the specified segment
|
||||
if str(child_chunk.segment_id) != str(segment.id):
|
||||
if child_chunk.segment_id != segment.id:
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
# validate args
|
||||
@ -639,4 +700,4 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
|
||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
return dump_response(ChildChunkDetailResponse, {"data": child_chunk}), 200
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
@ -37,6 +37,15 @@ from services.errors.llm import InvokeRateLimitError
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_agent_app_streaming(*, app_mode: AppMode, response_mode: str | None) -> bool:
|
||||
"""Agent App runtime is SSE-only until backend blocking runs are supported."""
|
||||
if app_mode != AppMode.AGENT:
|
||||
return response_mode == "streaming"
|
||||
if response_mode == "blocking":
|
||||
raise BadRequest("Agent App only supports streaming response mode.")
|
||||
return True
|
||||
|
||||
|
||||
class CompletionMessagePayload(BaseModel):
|
||||
inputs: dict[str, Any] = Field(description="Input variables for the completion")
|
||||
query: str = Field(default="", description="Query text for completion")
|
||||
@ -171,13 +180,13 @@ class ChatApi(WebApiResource):
|
||||
)
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
payload = ChatMessagePayload.model_validate(web_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
streaming = payload.response_mode == "streaming"
|
||||
streaming = _resolve_agent_app_streaming(app_mode=app_mode, response_mode=payload.response_mode)
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
@ -228,7 +237,7 @@ class ChatStopApi(WebApiResource):
|
||||
@web_ns.response(200, "Success", web_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
AppTaskService.stop_task(
|
||||
|
||||
@ -83,7 +83,7 @@ class ConversationListApi(WebApiResource):
|
||||
)
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
raw_args = request.args.to_dict()
|
||||
@ -129,7 +129,7 @@ class ConversationApi(WebApiResource):
|
||||
)
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@ -168,7 +168,7 @@ class ConversationRenameApi(WebApiResource):
|
||||
)
|
||||
def post(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@ -206,7 +206,7 @@ class ConversationPinApi(WebApiResource):
|
||||
@web_ns.response(200, "Conversation pinned successfully", web_ns.models[ResultResponse.__name__])
|
||||
def patch(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@ -237,7 +237,7 @@ class ConversationUnPinApi(WebApiResource):
|
||||
@web_ns.response(200, "Conversation unpinned successfully", web_ns.models[ResultResponse.__name__])
|
||||
def patch(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
@ -83,7 +83,7 @@ class MessageListApi(WebApiResource):
|
||||
)
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
raw_args = request.args.to_dict()
|
||||
@ -225,7 +225,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
||||
)
|
||||
def get(self, app_model: App, end_user: EndUser, message_id: UUID):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
message_id_str = str(message_id)
|
||||
|
||||
@ -9,7 +9,7 @@ from controllers.common.errors import (
|
||||
RemoteFileUploadError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file import remote_fetcher
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
|
||||
from graphon.file import helpers as file_helpers
|
||||
@ -60,10 +60,10 @@ class RemoteFileInfoApi(WebApiResource):
|
||||
HTTPException: If the remote file cannot be accessed
|
||||
"""
|
||||
decoded_url = helpers.decode_remote_url(url, request.query_string)
|
||||
resp = ssrf_proxy.head(decoded_url)
|
||||
resp = remote_fetcher.make_request("HEAD", decoded_url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
# failed back to get method
|
||||
resp = ssrf_proxy.get(decoded_url, timeout=3)
|
||||
resp = remote_fetcher.make_request("GET", decoded_url, timeout=3)
|
||||
resp.raise_for_status()
|
||||
info = RemoteFileInfo(
|
||||
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
|
||||
@ -112,9 +112,9 @@ class RemoteFileUploadApi(WebApiResource):
|
||||
url = str(payload.url)
|
||||
|
||||
try:
|
||||
resp = ssrf_proxy.head(url=url)
|
||||
resp = remote_fetcher.make_request("HEAD", url=url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
|
||||
resp = remote_fetcher.make_request("GET", url=url, timeout=3, follow_redirects=True)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
|
||||
except httpx.RequestError as e:
|
||||
@ -125,7 +125,7 @@ class RemoteFileUploadApi(WebApiResource):
|
||||
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
|
||||
raise FileTooLargeError
|
||||
|
||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
content = resp.content if resp.request.method == "GET" else remote_fetcher.make_request("GET", url).content
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
|
||||
@ -397,39 +397,40 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
current_scratchpad: AgentScratchpadUnit | None = None
|
||||
|
||||
for message in self.history_prompt_messages:
|
||||
if isinstance(message, AssistantPromptMessage):
|
||||
if not current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
current_scratchpad = AgentScratchpadUnit(
|
||||
agent_response=message.content,
|
||||
thought=message.content or "I am thinking about how to help you",
|
||||
action_str="",
|
||||
action=None,
|
||||
observation=None,
|
||||
)
|
||||
scratchpads.append(current_scratchpad)
|
||||
if message.tool_calls:
|
||||
try:
|
||||
current_scratchpad.action = AgentScratchpadUnit.Action(
|
||||
action_name=message.tool_calls[0].function.name,
|
||||
action_input=json.loads(message.tool_calls[0].function.arguments),
|
||||
match message:
|
||||
case AssistantPromptMessage():
|
||||
if not current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
current_scratchpad = AgentScratchpadUnit(
|
||||
agent_response=message.content,
|
||||
thought=message.content or "I am thinking about how to help you",
|
||||
action_str="",
|
||||
action=None,
|
||||
observation=None,
|
||||
)
|
||||
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
|
||||
except Exception:
|
||||
logger.exception("Failed to parse tool call from assistant message")
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
if current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
current_scratchpad.observation = message.content
|
||||
else:
|
||||
raise NotImplementedError("expected str type")
|
||||
elif isinstance(message, UserPromptMessage):
|
||||
if scratchpads:
|
||||
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
|
||||
scratchpads = []
|
||||
current_scratchpad = None
|
||||
scratchpads.append(current_scratchpad)
|
||||
if message.tool_calls:
|
||||
try:
|
||||
current_scratchpad.action = AgentScratchpadUnit.Action(
|
||||
action_name=message.tool_calls[0].function.name,
|
||||
action_input=json.loads(message.tool_calls[0].function.arguments),
|
||||
)
|
||||
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
|
||||
except Exception:
|
||||
logger.exception("Failed to parse tool call from assistant message")
|
||||
case ToolPromptMessage():
|
||||
if current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
current_scratchpad.observation = message.content
|
||||
else:
|
||||
raise NotImplementedError("expected str type")
|
||||
case UserPromptMessage():
|
||||
if scratchpads:
|
||||
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
|
||||
scratchpads = []
|
||||
current_scratchpad = None
|
||||
|
||||
result.append(message)
|
||||
result.append(message)
|
||||
|
||||
if scratchpads:
|
||||
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
|
||||
|
||||
@ -39,16 +39,17 @@ class CotCompletionAgentRunner(CotAgentRunner):
|
||||
historic_prompt = ""
|
||||
|
||||
for message in historic_prompt_messages:
|
||||
if isinstance(message, UserPromptMessage):
|
||||
historic_prompt += f"Question: {message.content}\n\n"
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
if isinstance(message.content, str):
|
||||
historic_prompt += message.content + "\n\n"
|
||||
elif isinstance(message.content, list):
|
||||
for content in message.content:
|
||||
if not isinstance(content, TextPromptMessageContent):
|
||||
continue
|
||||
historic_prompt += content.data
|
||||
match message:
|
||||
case UserPromptMessage():
|
||||
historic_prompt += f"Question: {message.content}\n\n"
|
||||
case AssistantPromptMessage():
|
||||
if isinstance(message.content, str):
|
||||
historic_prompt += message.content + "\n\n"
|
||||
elif isinstance(message.content, list):
|
||||
for content in message.content:
|
||||
if not isinstance(content, TextPromptMessageContent):
|
||||
continue
|
||||
historic_prompt += content.data
|
||||
|
||||
return historic_prompt
|
||||
|
||||
|
||||
@ -237,7 +237,9 @@ class EasyUIBasedAppConfig(AppConfig):
|
||||
"""
|
||||
|
||||
app_model_config_from: EasyUIBasedAppModelConfigFrom
|
||||
app_model_config_id: str
|
||||
# Optional: an Agent App has no legacy app_model_config row, so the id may be
|
||||
# absent (persistence then stores NULL for the conversation's id).
|
||||
app_model_config_id: str | None = None
|
||||
app_model_config_dict: dict[str, Any]
|
||||
model: ModelConfigEntity
|
||||
prompt_template: PromptTemplateEntity
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import RagPipelineVariableEntity
|
||||
from graphon.variables.input_entities import VariableEntity
|
||||
@ -20,10 +22,32 @@ class WorkflowVariablesConfigManager:
|
||||
|
||||
# variables
|
||||
for variable in user_input_form:
|
||||
cls._normalize_json_schema(variable)
|
||||
variables.append(VariableEntity.model_validate(variable))
|
||||
|
||||
return variables
|
||||
|
||||
@staticmethod
|
||||
def _normalize_json_schema(variable: dict[str, Any]) -> None:
|
||||
"""
|
||||
Normalize ``json_schema`` from a JSON string to a dict.
|
||||
|
||||
The workflow graph is stored as JSON in the database. When a JSON
|
||||
object variable carries a ``json_schema`` field, nested dicts are
|
||||
preserved correctly, but older data or certain serialization paths
|
||||
may store it as a JSON *string* instead of a native dict.
|
||||
|
||||
``VariableEntity.json_schema`` expects ``dict | None``, so we
|
||||
deserialize the string here before handing it to Pydantic.
|
||||
"""
|
||||
json_schema = variable.get("json_schema")
|
||||
if isinstance(json_schema, str):
|
||||
try:
|
||||
variable["json_schema"] = json.loads(json_schema)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# Leave as-is; Pydantic validation will surface the error.
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def convert_rag_pipeline_variable(cls, workflow: Workflow, start_node_id: str) -> list[RagPipelineVariableEntity]:
|
||||
"""
|
||||
|
||||
@ -134,17 +134,18 @@ class AdvancedChatAppGenerateResponseConverter(
|
||||
"created_at": chunk.created_at,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
||||
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
elif isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
match sub_stream_response:
|
||||
case MessageEndStreamResponse():
|
||||
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
case ErrorStreamResponse():
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
case NodeStartStreamResponse() | NodeFinishStreamResponse():
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
case _:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
|
||||
yield response_chunk
|
||||
|
||||
@ -161,16 +161,17 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
if isinstance(user, EndUser):
|
||||
self._user_id = user.id
|
||||
user_session_id = user.session_id
|
||||
self._created_by_role = CreatorUserRole.END_USER
|
||||
elif isinstance(user, Account):
|
||||
self._user_id = user.id
|
||||
user_session_id = user.id
|
||||
self._created_by_role = CreatorUserRole.ACCOUNT
|
||||
else:
|
||||
raise NotImplementedError(f"User type not supported: {type(user)}")
|
||||
match user:
|
||||
case EndUser():
|
||||
self._user_id = user.id
|
||||
user_session_id = user.session_id
|
||||
self._created_by_role = CreatorUserRole.END_USER
|
||||
case Account():
|
||||
self._user_id = user.id
|
||||
user_session_id = user.id
|
||||
self._created_by_role = CreatorUserRole.ACCOUNT
|
||||
case _:
|
||||
raise NotImplementedError(f"User type not supported: {type(user)}")
|
||||
|
||||
self._workflow_system_variables = build_system_variables(
|
||||
query=message.query,
|
||||
|
||||
0
api/core/app/apps/agent_app/__init__.py
Normal file
0
api/core/app/apps/agent_app/__init__.py
Normal file
106
api/core/app/apps/agent_app/app_config_manager.py
Normal file
106
api/core/app/apps/agent_app/app_config_manager.py
Normal file
@ -0,0 +1,106 @@
|
||||
"""Build the EasyUI-style app config for an Agent App from its Agent Soul.
|
||||
|
||||
An Agent App has no legacy ``app_model_config``: its model / prompt live in the
|
||||
bound Agent Soul snapshot. To ride the existing chat message + SSE pipeline we
|
||||
synthesize an ``app_model_config``-shaped dict from the Soul (model + system
|
||||
prompt) plus any app-level feature flags (opening statement, follow-up, …)
|
||||
stored on ``app_model_config`` when present, then reuse the same sub-managers
|
||||
the chat app type uses.
|
||||
"""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
|
||||
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
|
||||
from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager
|
||||
from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager
|
||||
from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager
|
||||
from core.app.app_config.entities import (
|
||||
EasyUIBasedAppConfig,
|
||||
EasyUIBasedAppModelConfigFrom,
|
||||
PromptTemplateEntity,
|
||||
)
|
||||
from models.agent_config_entities import AgentSoulConfig
|
||||
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation
|
||||
|
||||
|
||||
class AgentAppConfig(EasyUIBasedAppConfig):
|
||||
"""Agent App config entity (EasyUI-shaped so it rides the chat pipeline).
|
||||
|
||||
``app_model_config_id`` is inherited as ``str | None``: an Agent App may have
|
||||
no legacy ``app_model_config`` row, in which case persistence stores ``NULL``
|
||||
for the conversation's ``app_model_config_id``.
|
||||
"""
|
||||
|
||||
|
||||
class AgentAppConfigManager(BaseAppConfigManager):
|
||||
@classmethod
|
||||
def get_app_config(
|
||||
cls,
|
||||
*,
|
||||
app_model: App,
|
||||
agent_soul: AgentSoulConfig,
|
||||
app_model_config: AppModelConfig | None = None,
|
||||
conversation: Conversation | None = None,
|
||||
) -> AgentAppConfig:
|
||||
"""Build the Agent App config from the Agent Soul (+ optional feature flags)."""
|
||||
config_dict = cls._synthesize_config_dict(agent_soul, app_model_config)
|
||||
# The synthesized dict is shaped like an app_model_config; the EasyUI
|
||||
# sub-managers type their param as AppModelConfigDict (a TypedDict).
|
||||
typed_config = cast(AppModelConfigDict, config_dict)
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
||||
app_config = AgentAppConfig(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
app_mode=app_mode,
|
||||
# The config is derived from the Agent Soul snapshot, not a legacy
|
||||
# app_model_config row; the id is informational only.
|
||||
app_model_config_from=EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG,
|
||||
app_model_config_id=app_model_config.id if app_model_config else None,
|
||||
app_model_config_dict=config_dict,
|
||||
model=ModelConfigManager.convert(config=typed_config),
|
||||
prompt_template=PromptTemplateConfigManager.convert(config=typed_config),
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
|
||||
dataset=DatasetConfigManager.convert(config=typed_config),
|
||||
additional_features=cls.convert_features(config_dict, app_mode),
|
||||
)
|
||||
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
|
||||
config=typed_config
|
||||
)
|
||||
return app_config
|
||||
|
||||
@staticmethod
|
||||
def _synthesize_config_dict(
|
||||
agent_soul: AgentSoulConfig,
|
||||
app_model_config: AppModelConfig | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Shape a Soul + feature flags into an ``app_model_config``-style dict.
|
||||
|
||||
Feature flags (opening statement / follow-up / tts / stt / citations /
|
||||
moderation / annotation) come from ``app_model_config`` when present
|
||||
(Q3: stored there), otherwise defaults; model + prompt always come from
|
||||
the Agent Soul (the single source of truth for those).
|
||||
"""
|
||||
base: dict[str, Any] = dict(app_model_config.to_dict()) if app_model_config else {}
|
||||
|
||||
model = agent_soul.model
|
||||
if model is not None:
|
||||
base["model"] = {
|
||||
"provider": model.model_provider,
|
||||
"name": model.model,
|
||||
"mode": "chat",
|
||||
"completion_params": model.model_settings.model_dump(mode="json", exclude_none=True),
|
||||
}
|
||||
# The Agent Soul system prompt rides the EasyUI "simple" prompt slot; the
|
||||
# agent backend is the real prompt authority, this only feeds the chat
|
||||
# pipeline's bookkeeping (token counting, persistence).
|
||||
base["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
|
||||
base["pre_prompt"] = agent_soul.prompt.system_prompt or ""
|
||||
# Agent App takes the user message directly; no completion-style inputs form.
|
||||
base.setdefault("user_input_form", [])
|
||||
return base
|
||||
|
||||
|
||||
__all__ = ["AgentAppConfig", "AgentAppConfigManager"]
|
||||
327
api/core/app/apps/agent_app/app_generator.py
Normal file
327
api/core/app/apps/agent_app/app_generator.py
Normal file
@ -0,0 +1,327 @@
|
||||
"""Agent App generator: orchestrate one conversation turn for an Agent App.
|
||||
|
||||
Mirrors the agent_chat generator (conversation + message + queue + streamed
|
||||
response over the EasyUI chat pipeline), but the backing config comes from the
|
||||
bound Agent Soul and the answer is produced by ``AgentAppRunner`` calling the
|
||||
dify-agent backend rather than an in-process LLM/ReAct loop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import select
|
||||
|
||||
from clients.agent_backend import AgentBackendRunEventAdapter
|
||||
from clients.agent_backend.factory import create_agent_backend_run_client
|
||||
from configs import dify_config
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.apps.agent_app.app_config_manager import AgentAppConfigManager
|
||||
from core.app.apps.agent_app.app_runner import AgentAppRunner
|
||||
from core.app.apps.agent_app.generate_response_converter import AgentAppGenerateResponseConverter
|
||||
from core.app.apps.agent_app.runtime_request_builder import AgentAppRuntimeRequestBuilder
|
||||
from core.app.apps.agent_app.session_store import AgentAppRuntimeSessionStore
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AgentAppGenerateEntity,
|
||||
DifyRunContext,
|
||||
InvokeFrom,
|
||||
UserFrom,
|
||||
)
|
||||
from core.app.llm.model_access import build_dify_model_access
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, EndUser, Message
|
||||
from models.agent import Agent, AgentConfigSnapshot, AgentScope, AgentSource, AgentStatus
|
||||
from models.agent_config_entities import AgentSoulConfig
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentAppGeneratorError(ValueError):
|
||||
"""Raised when an Agent App turn cannot be set up."""
|
||||
|
||||
|
||||
class AgentAppGenerator(MessageBasedAppGenerator):
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[Mapping | str, None, None]:
|
||||
if not streaming:
|
||||
raise AgentAppGeneratorError("Agent App only supports streaming mode")
|
||||
|
||||
query = args.get("query")
|
||||
if not isinstance(query, str) or not query.strip():
|
||||
raise AgentAppGeneratorError("query is required")
|
||||
query = query.replace("\x00", "")
|
||||
inputs = args["inputs"]
|
||||
|
||||
# Resolve the bound roster Agent + its published Agent Soul snapshot.
|
||||
agent, snapshot, agent_soul = self._resolve_agent(app_model)
|
||||
|
||||
conversation = None
|
||||
conversation_id = args.get("conversation_id")
|
||||
if conversation_id:
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model, conversation_id=conversation_id, user=user
|
||||
)
|
||||
|
||||
# Build the EasyUI-shaped config from the Agent Soul so the chat pipeline
|
||||
# can persist usage; the answer itself comes from the agent backend.
|
||||
app_model_config = app_model.app_model_config
|
||||
app_config = AgentAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
agent_soul=agent_soul,
|
||||
app_model_config=app_model_config,
|
||||
conversation=conversation,
|
||||
)
|
||||
model_conf = ModelConfigConverter.convert(app_config)
|
||||
|
||||
trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id)
|
||||
|
||||
application_generate_entity = AgentAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
model_conf=model_conf,
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=self._prepare_user_inputs(
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
files=[],
|
||||
parent_message_id=(
|
||||
args.get("parent_message_id")
|
||||
if invoke_from not in {InvokeFrom.SERVICE_API, InvokeFrom.OPENAPI}
|
||||
else UUID_NIL
|
||||
),
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
extras={"auto_generate_conversation_name": args.get("auto_generate_name", True)},
|
||||
call_depth=0,
|
||||
trace_manager=trace_manager,
|
||||
agent_id=agent.id,
|
||||
agent_config_snapshot_id=snapshot.id,
|
||||
)
|
||||
|
||||
conversation, message = self._init_generate_records(application_generate_entity, conversation)
|
||||
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
context = contextvars.copy_context()
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"context": context,
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"conversation_id": conversation.id,
|
||||
"message_id": message.id,
|
||||
"user_from": UserFrom.ACCOUNT if isinstance(user, Account) else UserFrom.END_USER,
|
||||
},
|
||||
)
|
||||
worker_thread.start()
|
||||
|
||||
response = self._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=streaming,
|
||||
)
|
||||
return AgentAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
||||
def _generate_worker(
|
||||
self,
|
||||
*,
|
||||
flask_app: Flask,
|
||||
context: contextvars.Context,
|
||||
application_generate_entity: AgentAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
user_from: UserFrom,
|
||||
) -> None:
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||
try:
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
message = self._get_message(message_id)
|
||||
app_config = application_generate_entity.app_config
|
||||
|
||||
# Apply app-level input guards (content moderation + annotation
|
||||
# reply) before reaching the Agent backend, mirroring the EasyUI
|
||||
# chat / agent-chat runners. These can short-circuit the turn.
|
||||
app_model = db.session.get(App, app_config.app_id)
|
||||
if app_model is None:
|
||||
raise AgentAppGeneratorError("App not found")
|
||||
handled, query = self._run_input_guards(
|
||||
application_generate_entity=application_generate_entity,
|
||||
app_model=app_model,
|
||||
message=message,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
if handled:
|
||||
return
|
||||
|
||||
dify_context = DifyRunContext(
|
||||
tenant_id=app_config.tenant_id,
|
||||
app_id=app_config.app_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
)
|
||||
credentials_provider, _ = build_dify_model_access(dify_context)
|
||||
_, _, agent_soul = self._resolve_agent_by_id(
|
||||
tenant_id=app_config.tenant_id,
|
||||
agent_id=application_generate_entity.agent_id,
|
||||
snapshot_id=application_generate_entity.agent_config_snapshot_id,
|
||||
)
|
||||
|
||||
runner = AgentAppRunner(
|
||||
request_builder=AgentAppRuntimeRequestBuilder(credentials_provider=credentials_provider),
|
||||
agent_backend_client=create_agent_backend_run_client(
|
||||
base_url=dify_config.AGENT_BACKEND_BASE_URL,
|
||||
use_fake=dify_config.AGENT_BACKEND_USE_FAKE,
|
||||
fake_scenario=dify_config.AGENT_BACKEND_FAKE_SCENARIO,
|
||||
),
|
||||
event_adapter=AgentBackendRunEventAdapter(),
|
||||
session_store=AgentAppRuntimeSessionStore(),
|
||||
)
|
||||
runner.run(
|
||||
dify_context=dify_context,
|
||||
agent_id=application_generate_entity.agent_id,
|
||||
agent_config_snapshot_id=application_generate_entity.agent_config_snapshot_id,
|
||||
agent_soul=agent_soul,
|
||||
conversation_id=conversation.id,
|
||||
query=query,
|
||||
message_id=message.id,
|
||||
model_name=application_generate_entity.model_conf.model,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error in Agent App generate worker")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _run_input_guards(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: AgentAppGenerateEntity,
|
||||
app_model: App,
|
||||
message: Message,
|
||||
queue_manager: AppQueueManager,
|
||||
) -> tuple[bool, str]:
|
||||
"""Apply input moderation + annotation reply before the backend call.
|
||||
|
||||
Returns ``(handled, query)``: when ``handled`` is True a direct answer
|
||||
has already been published (a blocked/preset moderation response or a
|
||||
matched annotation) and the backend turn must be skipped. Otherwise
|
||||
``query`` is the possibly moderation-overridden query to send onward.
|
||||
"""
|
||||
from core.app.apps.agent_app.app_runner import publish_text_answer
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
|
||||
app_config = application_generate_entity.app_config
|
||||
model_name = application_generate_entity.model_conf.model
|
||||
query = application_generate_entity.query
|
||||
|
||||
# content moderation (sensitive_word_avoidance); a blocked input yields a
|
||||
# preset answer, an "overridden" action returns a sanitized query.
|
||||
try:
|
||||
_, _, query = InputModeration().check(
|
||||
app_id=app_config.app_id,
|
||||
tenant_id=app_config.tenant_id,
|
||||
app_config=app_config,
|
||||
inputs=dict(application_generate_entity.inputs),
|
||||
query=query or "",
|
||||
message_id=message.id,
|
||||
trace_manager=application_generate_entity.trace_manager,
|
||||
)
|
||||
except ModerationError as e:
|
||||
publish_text_answer(queue_manager=queue_manager, model_name=model_name, answer=str(e))
|
||||
return True, query
|
||||
|
||||
# annotation reply: a matching annotation answers the turn deterministically.
|
||||
if query:
|
||||
annotation_reply = AnnotationReplyFeature().query(
|
||||
app_record=app_model,
|
||||
message=message,
|
||||
query=query,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
)
|
||||
if annotation_reply:
|
||||
queue_manager.publish(
|
||||
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
publish_text_answer(queue_manager=queue_manager, model_name=model_name, answer=annotation_reply.content)
|
||||
return True, query
|
||||
|
||||
return False, query
|
||||
|
||||
def _resolve_agent(self, app_model: App) -> tuple[Agent, AgentConfigSnapshot, AgentSoulConfig]:
|
||||
agent = db.session.scalar(
|
||||
select(Agent).where(
|
||||
Agent.app_id == app_model.id,
|
||||
Agent.scope == AgentScope.ROSTER,
|
||||
Agent.source == AgentSource.AGENT_APP,
|
||||
Agent.status == AgentStatus.ACTIVE,
|
||||
)
|
||||
)
|
||||
if agent is None:
|
||||
raise AgentAppGeneratorError("Agent App has no bound Agent")
|
||||
return self._resolve_agent_by_id(
|
||||
tenant_id=app_model.tenant_id, agent_id=agent.id, snapshot_id=agent.active_config_snapshot_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_agent_by_id(
|
||||
*, tenant_id: str, agent_id: str, snapshot_id: str | None
|
||||
) -> tuple[Agent, AgentConfigSnapshot, AgentSoulConfig]:
|
||||
agent = db.session.scalar(select(Agent).where(Agent.id == agent_id, Agent.tenant_id == tenant_id))
|
||||
if agent is None:
|
||||
raise AgentAppGeneratorError("Agent not found")
|
||||
if not snapshot_id:
|
||||
raise AgentAppGeneratorError("Agent has no published version")
|
||||
snapshot = db.session.scalar(select(AgentConfigSnapshot).where(AgentConfigSnapshot.id == snapshot_id))
|
||||
if snapshot is None:
|
||||
raise AgentAppGeneratorError("Agent published version not found")
|
||||
agent_soul = AgentSoulConfig.model_validate(snapshot.config_snapshot_dict)
|
||||
return agent, snapshot, agent_soul
|
||||
|
||||
|
||||
__all__ = ["AgentAppGenerator", "AgentAppGeneratorError"]
|
||||
200
api/core/app/apps/agent_app/app_runner.py
Normal file
200
api/core/app/apps/agent_app/app_runner.py
Normal file
@ -0,0 +1,200 @@
|
||||
"""Agent App runner: drive one conversation turn through the dify-agent backend.
|
||||
|
||||
Unlike the legacy ``AgentChatAppRunner`` (which runs an in-process ReAct loop),
|
||||
this runner delegates to the Agent backend: build the run request from the
|
||||
Agent Soul + conversation, create the run, consume its event stream, and
|
||||
republish the assistant answer as chat queue events so the existing
|
||||
EasyUI chat task pipeline persists the message and streams SSE. The conversation
|
||||
``session_snapshot`` is saved on success for multi-turn continuity (S3).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from clients.agent_backend import (
|
||||
AgentBackendError,
|
||||
AgentBackendInternalEventType,
|
||||
AgentBackendRunClient,
|
||||
AgentBackendRunEventAdapter,
|
||||
AgentBackendRunSucceededInternalEvent,
|
||||
AgentBackendStreamInternalEvent,
|
||||
)
|
||||
from core.app.apps.agent_app.runtime_request_builder import (
|
||||
AgentAppRuntimeBuildContext,
|
||||
AgentAppRuntimeRequestBuilder,
|
||||
)
|
||||
from core.app.apps.agent_app.session_store import AgentAppRuntimeSessionStore, AgentAppSessionScope
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import DifyRunContext
|
||||
from core.app.entities.queue_entities import QueueLLMChunkEvent, QueueMessageEndEvent
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
from models.agent_config_entities import AgentSoulConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def publish_text_answer(*, queue_manager: AppQueueManager, model_name: str, answer: str) -> None:
|
||||
"""Publish a complete assistant answer as one chunk + message-end.
|
||||
|
||||
The EasyUI chat task pipeline consumes a QueueLLMChunkEvent stream followed
|
||||
by a QueueMessageEndEvent; emitting the whole answer as a single chunk lets
|
||||
both the backend-produced answer and short-circuited answers (moderation /
|
||||
annotation reply) share the exact same persistence + SSE path.
|
||||
"""
|
||||
chunk = LLMResultChunk(
|
||||
model=model_name,
|
||||
prompt_messages=[],
|
||||
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=answer)),
|
||||
)
|
||||
queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER)
|
||||
queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=model_name,
|
||||
prompt_messages=[],
|
||||
message=AssistantPromptMessage(content=answer),
|
||||
usage=LLMUsage.empty_usage(),
|
||||
),
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
|
||||
class AgentAppRunner:
|
||||
"""Runs one Agent App conversation turn against the Agent backend."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
request_builder: AgentAppRuntimeRequestBuilder,
|
||||
agent_backend_client: AgentBackendRunClient,
|
||||
event_adapter: AgentBackendRunEventAdapter,
|
||||
session_store: AgentAppRuntimeSessionStore,
|
||||
) -> None:
|
||||
self._request_builder = request_builder
|
||||
self._agent_backend_client = agent_backend_client
|
||||
self._event_adapter = event_adapter
|
||||
self._session_store = session_store
|
||||
|
||||
def run(
|
||||
self,
|
||||
*,
|
||||
dify_context: DifyRunContext,
|
||||
agent_id: str,
|
||||
agent_config_snapshot_id: str,
|
||||
agent_soul: AgentSoulConfig,
|
||||
conversation_id: str,
|
||||
query: str,
|
||||
message_id: str,
|
||||
model_name: str,
|
||||
queue_manager: AppQueueManager,
|
||||
) -> None:
|
||||
scope = AgentAppSessionScope(
|
||||
tenant_id=dify_context.tenant_id,
|
||||
app_id=dify_context.app_id,
|
||||
conversation_id=conversation_id,
|
||||
agent_id=agent_id,
|
||||
agent_config_snapshot_id=agent_config_snapshot_id,
|
||||
)
|
||||
session_snapshot = self._session_store.load_active_snapshot(scope)
|
||||
|
||||
runtime = self._request_builder.build(
|
||||
AgentAppRuntimeBuildContext(
|
||||
dify_context=dify_context,
|
||||
agent_id=agent_id,
|
||||
agent_config_snapshot_id=agent_config_snapshot_id,
|
||||
agent_soul=agent_soul,
|
||||
conversation_id=conversation_id,
|
||||
user_query=query,
|
||||
idempotency_key=message_id,
|
||||
session_snapshot=session_snapshot,
|
||||
)
|
||||
)
|
||||
|
||||
create_response = self._agent_backend_client.create_run(runtime.request)
|
||||
terminal = self._consume_stream(create_response.run_id, queue_manager=queue_manager)
|
||||
|
||||
if not isinstance(terminal, AgentBackendRunSucceededInternalEvent):
|
||||
error = getattr(terminal, "error", None) or "Agent backend run did not complete successfully."
|
||||
raise AgentBackendError(str(error))
|
||||
|
||||
answer = self._extract_answer(terminal.output)
|
||||
self._publish_answer(queue_manager=queue_manager, model_name=model_name, answer=answer)
|
||||
self._save_session(scope=scope, backend_run_id=terminal.run_id, snapshot=terminal.session_snapshot)
|
||||
|
||||
def _consume_stream(self, run_id: str, *, queue_manager: AppQueueManager):
|
||||
terminal = None
|
||||
for public_event in self._agent_backend_client.stream_events(run_id):
|
||||
if queue_manager.is_stopped():
|
||||
self._cancel_run(run_id)
|
||||
raise GenerateTaskStoppedError()
|
||||
for internal_event in self._event_adapter.adapt(public_event):
|
||||
if queue_manager.is_stopped():
|
||||
self._cancel_run(run_id)
|
||||
raise GenerateTaskStoppedError()
|
||||
if internal_event.type in (
|
||||
AgentBackendInternalEventType.RUN_STARTED,
|
||||
AgentBackendInternalEventType.STREAM_EVENT,
|
||||
):
|
||||
# Stream deltas are accumulated by the backend into the
|
||||
# terminal output; token-level forwarding is an S3 refinement.
|
||||
if isinstance(internal_event, AgentBackendStreamInternalEvent):
|
||||
continue
|
||||
continue
|
||||
terminal = internal_event
|
||||
break
|
||||
if terminal is not None:
|
||||
break
|
||||
return terminal
|
||||
|
||||
def _cancel_run(self, run_id: str) -> None:
|
||||
try:
|
||||
self._agent_backend_client.cancel_run(run_id)
|
||||
except Exception:
|
||||
logger.warning("Failed to cancel stopped Agent App backend run: run_id=%s", run_id, exc_info=True)
|
||||
|
||||
def _publish_answer(self, *, queue_manager: AppQueueManager, model_name: str, answer: str) -> None:
|
||||
# MVP: emit the full answer as a single chunk + message-end. The chat
|
||||
# task pipeline streams the chunk over SSE and persists the message.
|
||||
publish_text_answer(queue_manager=queue_manager, model_name=model_name, answer=answer)
|
||||
|
||||
def _save_session(self, *, scope: AgentAppSessionScope, backend_run_id: str, snapshot: Any) -> None:
|
||||
try:
|
||||
self._session_store.save_active_snapshot(scope=scope, backend_run_id=backend_run_id, snapshot=snapshot)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist Agent App conversation session snapshot: "
|
||||
"tenant_id=%s app_id=%s conversation_id=%s agent_id=%s",
|
||||
scope.tenant_id,
|
||||
scope.app_id,
|
||||
scope.conversation_id,
|
||||
scope.agent_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_answer(output: JsonValue) -> str:
|
||||
"""Normalize the backend's terminal output to assistant text.
|
||||
|
||||
Free-text Agent Apps return a plain string; if a structured output is
|
||||
configured the value is a JSON object, which we serialize so the chat
|
||||
message always has a string body.
|
||||
"""
|
||||
if isinstance(output, str):
|
||||
return output
|
||||
if isinstance(output, dict):
|
||||
text = output.get("text")
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
return json.dumps(output, ensure_ascii=False)
|
||||
return json.dumps(output, ensure_ascii=False)
|
||||
|
||||
|
||||
__all__ = ["AgentAppRunner", "publish_text_answer"]
|
||||
15
api/core/app/apps/agent_app/generate_response_converter.py
Normal file
15
api/core/app/apps/agent_app/generate_response_converter.py
Normal file
@ -0,0 +1,15 @@
|
||||
"""Response converter for the Agent App type.
|
||||
|
||||
The Agent App streams the same chatbot response shape as the chat / agent-chat
|
||||
app types, so it reuses that converter wholesale; kept as a distinct subclass so
|
||||
the app type owns its converter and can diverge later.
|
||||
"""
|
||||
|
||||
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
|
||||
|
||||
|
||||
class AgentAppGenerateResponseConverter(AgentChatAppGenerateResponseConverter):
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ["AgentAppGenerateResponseConverter"]
|
||||
177
api/core/app/apps/agent_app/runtime_request_builder.py
Normal file
177
api/core/app/apps/agent_app/runtime_request_builder.py
Normal file
@ -0,0 +1,177 @@
|
||||
"""Build dify-agent run requests for one Agent App conversation turn.
|
||||
|
||||
Mirrors the workflow ``WorkflowAgentRuntimeRequestBuilder`` but for the Agent
|
||||
App surface: the user prompt is the chat message (no workflow-node job / no
|
||||
previous-node context), and multi-turn continuity flows through the
|
||||
conversation-keyed ``session_snapshot`` plus the history layer.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig
|
||||
from dify_agent.protocol import CreateRunRequest
|
||||
|
||||
from clients.agent_backend import (
|
||||
AgentBackendAgentAppRunInput,
|
||||
AgentBackendModelConfig,
|
||||
AgentBackendRunRequestBuilder,
|
||||
redact_for_agent_backend_log,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import DifyRunContext
|
||||
from core.workflow.nodes.agent_v2.plugin_tools_builder import (
|
||||
WorkflowAgentPluginToolsBuilder,
|
||||
WorkflowAgentPluginToolsBuildError,
|
||||
)
|
||||
from models.agent_config_entities import AgentSoulConfig
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
|
||||
class AgentAppRuntimeRequestBuildError(ValueError):
|
||||
"""Raised when Agent App state cannot be mapped to a valid run request."""
|
||||
|
||||
def __init__(self, error_code: str, message: str) -> None:
|
||||
self.error_code = error_code
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class CredentialsProvider(Protocol):
|
||||
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AgentAppRuntimeBuildContext:
|
||||
dify_context: DifyRunContext
|
||||
agent_id: str
|
||||
agent_config_snapshot_id: str
|
||||
agent_soul: AgentSoulConfig
|
||||
conversation_id: str
|
||||
user_query: str
|
||||
idempotency_key: str
|
||||
session_snapshot: CompositorSessionSnapshot | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AgentAppRuntimeRequest:
|
||||
request: CreateRunRequest
|
||||
redacted_request: dict[str, Any]
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
class AgentAppRuntimeRequestBuilder:
|
||||
"""Build dify-agent run requests from Agent App conversation state."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
credentials_provider: CredentialsProvider,
|
||||
request_builder: AgentBackendRunRequestBuilder | None = None,
|
||||
plugin_tools_builder: WorkflowAgentPluginToolsBuilder | None = None,
|
||||
) -> None:
|
||||
self._credentials_provider = credentials_provider
|
||||
self._request_builder = request_builder or AgentBackendRunRequestBuilder()
|
||||
self._plugin_tools_builder = plugin_tools_builder or WorkflowAgentPluginToolsBuilder()
|
||||
|
||||
def build(self, context: AgentAppRuntimeBuildContext) -> AgentAppRuntimeRequest:
|
||||
agent_soul = context.agent_soul
|
||||
if agent_soul.model is None:
|
||||
raise AgentAppRuntimeRequestBuildError(
|
||||
"agent_model_not_configured",
|
||||
"Agent App requires the Agent Soul model to be configured.",
|
||||
)
|
||||
|
||||
metadata = self._build_metadata(context)
|
||||
credentials = self._credentials_provider.fetch(agent_soul.model.model_provider, agent_soul.model.model)
|
||||
try:
|
||||
tools_layer = self._plugin_tools_builder.build(
|
||||
tenant_id=context.dify_context.tenant_id,
|
||||
app_id=context.dify_context.app_id,
|
||||
user_id=context.dify_context.user_id,
|
||||
tools=agent_soul.tools,
|
||||
invoke_from=context.dify_context.invoke_from,
|
||||
)
|
||||
except WorkflowAgentPluginToolsBuildError as error:
|
||||
raise AgentAppRuntimeRequestBuildError(error.error_code, str(error)) from error
|
||||
if tools_layer is not None:
|
||||
metadata["agent_tools"] = {
|
||||
"dify_tool_count": len(tools_layer.tools),
|
||||
"dify_tool_names": [tool.name or tool.tool_name for tool in tools_layer.tools],
|
||||
}
|
||||
|
||||
request = self._request_builder.build_for_agent_app(
|
||||
AgentBackendAgentAppRunInput(
|
||||
model=AgentBackendModelConfig(
|
||||
plugin_id=self._plugin_daemon_plugin_id(
|
||||
plugin_id=agent_soul.model.plugin_id,
|
||||
model_provider=agent_soul.model.model_provider,
|
||||
),
|
||||
model_provider=self._plugin_daemon_provider_name(agent_soul.model.model_provider),
|
||||
model=agent_soul.model.model,
|
||||
credentials=self._normalize_credentials(credentials),
|
||||
model_settings=agent_soul.model.model_settings.model_dump(mode="json", exclude_none=True),
|
||||
),
|
||||
execution_context=DifyExecutionContextLayerConfig(
|
||||
tenant_id=context.dify_context.tenant_id,
|
||||
user_id=context.dify_context.user_id,
|
||||
app_id=context.dify_context.app_id,
|
||||
conversation_id=context.conversation_id,
|
||||
agent_id=context.agent_id,
|
||||
agent_config_version_id=context.agent_config_snapshot_id,
|
||||
invoke_from="agent_app",
|
||||
),
|
||||
agent_soul_prompt=agent_soul.prompt.system_prompt or None,
|
||||
user_prompt=context.user_query,
|
||||
tools=tools_layer,
|
||||
session_snapshot=context.session_snapshot,
|
||||
idempotency_key=context.idempotency_key,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
redacted = cast(dict[str, Any], redact_for_agent_backend_log(request))
|
||||
return AgentAppRuntimeRequest(request=request, redacted_request=redacted, metadata=metadata)
|
||||
|
||||
@staticmethod
|
||||
def _build_metadata(context: AgentAppRuntimeBuildContext) -> dict[str, Any]:
|
||||
return {
|
||||
"tenant_id": context.dify_context.tenant_id,
|
||||
"app_id": context.dify_context.app_id,
|
||||
"conversation_id": context.conversation_id,
|
||||
"agent_id": context.agent_id,
|
||||
"agent_config_snapshot_id": context.agent_config_snapshot_id,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _plugin_daemon_plugin_id(*, plugin_id: str, model_provider: str) -> str:
|
||||
"""Return the transport plugin id expected by plugin-daemon headers."""
|
||||
if plugin_id.count("/") == 1:
|
||||
return plugin_id
|
||||
if plugin_id:
|
||||
return ModelProviderID(plugin_id).plugin_id
|
||||
return ModelProviderID(model_provider).plugin_id
|
||||
|
||||
@staticmethod
|
||||
def _plugin_daemon_provider_name(model_provider: str) -> str:
|
||||
"""Return the provider name expected by plugin-daemon dispatch payloads."""
|
||||
return ModelProviderID(model_provider).provider_name
|
||||
|
||||
@staticmethod
|
||||
def _normalize_credentials(credentials: Mapping[str, Any]) -> dict[str, str | int | float | bool | None]:
|
||||
normalized: dict[str, str | int | float | bool | None] = {}
|
||||
for key, value in credentials.items():
|
||||
if isinstance(value, str | int | float | bool) or value is None:
|
||||
normalized[key] = value
|
||||
else:
|
||||
normalized[key] = str(value)
|
||||
return normalized
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AgentAppRuntimeBuildContext",
|
||||
"AgentAppRuntimeRequest",
|
||||
"AgentAppRuntimeRequestBuildError",
|
||||
"AgentAppRuntimeRequestBuilder",
|
||||
]
|
||||
106
api/core/app/apps/agent_app/session_store.py
Normal file
106
api/core/app/apps/agent_app/session_store.py
Normal file
@ -0,0 +1,106 @@
|
||||
"""Conversation-keyed Agent backend session store for the Agent App type.
|
||||
|
||||
Shares the unified ``agent_runtime_sessions`` table with the workflow Agent
|
||||
Node store, but owns rows with ``owner_type = conversation``: one Agent App
|
||||
conversation maps to one Agent session, so multi-turn chat re-enters the same
|
||||
``session_snapshot``. Cross-conversation memory (PRD Global / Per app) is a
|
||||
phase-2 concern and not modeled here.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.agent import (
|
||||
AgentRuntimeSession,
|
||||
AgentRuntimeSessionOwnerType,
|
||||
AgentRuntimeSessionStatus,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AgentAppSessionScope:
|
||||
"""Identity of one Agent App conversation session."""
|
||||
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
conversation_id: str
|
||||
agent_id: str
|
||||
agent_config_snapshot_id: str
|
||||
|
||||
|
||||
class AgentAppRuntimeSessionStore:
|
||||
"""Persists Agent backend session snapshots for Agent App conversations."""
|
||||
|
||||
def load_active_snapshot(self, scope: AgentAppSessionScope) -> CompositorSessionSnapshot | None:
|
||||
with session_factory.create_session() as session:
|
||||
row = session.scalar(self._active_stmt(scope))
|
||||
if row is None:
|
||||
return None
|
||||
return CompositorSessionSnapshot.model_validate_json(row.session_snapshot)
|
||||
|
||||
def save_active_snapshot(
|
||||
self,
|
||||
*,
|
||||
scope: AgentAppSessionScope,
|
||||
backend_run_id: str,
|
||||
snapshot: CompositorSessionSnapshot | None,
|
||||
) -> None:
|
||||
if snapshot is None:
|
||||
return
|
||||
snapshot_json = snapshot.model_dump_json()
|
||||
with session_factory.create_session() as session:
|
||||
row = session.scalar(self._scope_stmt(scope))
|
||||
if row is None:
|
||||
row = AgentRuntimeSession(
|
||||
tenant_id=scope.tenant_id,
|
||||
app_id=scope.app_id,
|
||||
owner_type=AgentRuntimeSessionOwnerType.CONVERSATION,
|
||||
agent_id=scope.agent_id,
|
||||
agent_config_snapshot_id=scope.agent_config_snapshot_id,
|
||||
conversation_id=scope.conversation_id,
|
||||
backend_run_id=backend_run_id,
|
||||
session_snapshot=snapshot_json,
|
||||
composition_layer_specs="[]",
|
||||
status=AgentRuntimeSessionStatus.ACTIVE,
|
||||
)
|
||||
session.add(row)
|
||||
else:
|
||||
row.backend_run_id = backend_run_id
|
||||
row.session_snapshot = snapshot_json
|
||||
row.status = AgentRuntimeSessionStatus.ACTIVE
|
||||
row.cleaned_at = None
|
||||
session.commit()
|
||||
|
||||
def mark_cleaned(self, *, scope: AgentAppSessionScope, backend_run_id: str | None = None) -> None:
|
||||
with session_factory.create_session() as session:
|
||||
row = session.scalar(self._active_stmt(scope))
|
||||
if row is None:
|
||||
return
|
||||
if backend_run_id is not None:
|
||||
row.backend_run_id = backend_run_id
|
||||
row.status = AgentRuntimeSessionStatus.CLEANED
|
||||
row.cleaned_at = naive_utc_now()
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def _scope_stmt(scope: AgentAppSessionScope):
|
||||
return select(AgentRuntimeSession).where(
|
||||
AgentRuntimeSession.owner_type == AgentRuntimeSessionOwnerType.CONVERSATION,
|
||||
AgentRuntimeSession.tenant_id == scope.tenant_id,
|
||||
AgentRuntimeSession.conversation_id == scope.conversation_id,
|
||||
AgentRuntimeSession.agent_id == scope.agent_id,
|
||||
AgentRuntimeSession.agent_config_snapshot_id == scope.agent_config_snapshot_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _active_stmt(cls, scope: AgentAppSessionScope):
|
||||
return cls._scope_stmt(scope).where(AgentRuntimeSession.status == AgentRuntimeSessionStatus.ACTIVE)
|
||||
|
||||
|
||||
__all__ = ["AgentAppRuntimeSessionStore", "AgentAppSessionScope"]
|
||||
@ -112,15 +112,16 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
|
||||
"created_at": chunk.created_at,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
||||
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
elif isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
match sub_stream_response:
|
||||
case MessageEndStreamResponse():
|
||||
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
case ErrorStreamResponse():
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
case _:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
|
||||
yield response_chunk
|
||||
|
||||
@ -134,6 +134,10 @@ class AppQueueManager(ABC):
|
||||
self._check_for_sqlalchemy_models(event.model_dump())
|
||||
self._publish(event, pub_from)
|
||||
|
||||
def is_stopped(self) -> bool:
|
||||
"""Return whether the current task has been manually stopped."""
|
||||
return self._is_stopped()
|
||||
|
||||
@abstractmethod
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
@ -209,14 +213,16 @@ class AppQueueManager(ABC):
|
||||
|
||||
def _check_for_sqlalchemy_models(self, data: Any):
|
||||
# from entity to dict or list
|
||||
if isinstance(data, dict):
|
||||
for value in data.values():
|
||||
self._check_for_sqlalchemy_models(value)
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
self._check_for_sqlalchemy_models(item)
|
||||
else:
|
||||
if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
|
||||
raise TypeError(
|
||||
"Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed."
|
||||
)
|
||||
match data:
|
||||
case dict():
|
||||
for value in data.values():
|
||||
self._check_for_sqlalchemy_models(value)
|
||||
case list():
|
||||
for item in data:
|
||||
self._check_for_sqlalchemy_models(item)
|
||||
case _:
|
||||
if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
|
||||
raise TypeError(
|
||||
"Critical Error: Passing SQLAlchemy Model instances that"
|
||||
" cause thread safety issues is not allowed."
|
||||
)
|
||||
|
||||
@ -305,26 +305,27 @@ class AppRunner:
|
||||
text += message.content
|
||||
elif isinstance(message.content, list):
|
||||
for content in message.content:
|
||||
if isinstance(content, str):
|
||||
text += content
|
||||
elif isinstance(content, TextPromptMessageContent):
|
||||
text += content.data
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
if message_id and user_id and tenant_id:
|
||||
try:
|
||||
self._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
except Exception:
|
||||
_logger.exception("Failed to handle multimodal image output")
|
||||
else:
|
||||
_logger.warning("Received multimodal output but missing required parameters")
|
||||
else:
|
||||
text += content.data if hasattr(content, "data") else str(content)
|
||||
match content:
|
||||
case str():
|
||||
text += content
|
||||
case TextPromptMessageContent():
|
||||
text += content.data
|
||||
case ImagePromptMessageContent():
|
||||
if message_id and user_id and tenant_id:
|
||||
try:
|
||||
self._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
except Exception:
|
||||
_logger.exception("Failed to handle multimodal image output")
|
||||
else:
|
||||
_logger.warning("Received multimodal output but missing required parameters")
|
||||
case _:
|
||||
text += content.data if hasattr(content, "data") else str(content)
|
||||
|
||||
if not model:
|
||||
model = result.model
|
||||
|
||||
@ -112,15 +112,16 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
|
||||
"created_at": chunk.created_at,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
||||
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
elif isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
match sub_stream_response:
|
||||
case MessageEndStreamResponse():
|
||||
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
case ErrorStreamResponse():
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
case _:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
|
||||
yield response_chunk
|
||||
|
||||
@ -276,17 +276,18 @@ class WorkflowResponseConverter:
|
||||
|
||||
created_by: CreatedByDict | dict[str, object] = {}
|
||||
user = self._user
|
||||
if isinstance(user, Account):
|
||||
created_by = AccountCreatedByDict(
|
||||
id=user.id,
|
||||
name=user.name,
|
||||
email=user.email,
|
||||
)
|
||||
elif isinstance(user, EndUser):
|
||||
created_by = EndUserCreatedByDict(
|
||||
id=user.id,
|
||||
user=user.session_id,
|
||||
)
|
||||
match user:
|
||||
case Account():
|
||||
created_by = AccountCreatedByDict(
|
||||
id=user.id,
|
||||
name=user.name,
|
||||
email=user.email,
|
||||
)
|
||||
case EndUser():
|
||||
created_by = EndUserCreatedByDict(
|
||||
id=user.id,
|
||||
user=user.session_id,
|
||||
)
|
||||
|
||||
return WorkflowFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
@ -455,17 +456,18 @@ class WorkflowResponseConverter:
|
||||
|
||||
created_by: Mapping[str, object]
|
||||
user = creator_user
|
||||
if isinstance(user, Account):
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"name": user.name,
|
||||
"email": user.email,
|
||||
}
|
||||
else:
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"user": user.session_id,
|
||||
}
|
||||
match user:
|
||||
case Account():
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"name": user.name,
|
||||
"email": user.email,
|
||||
}
|
||||
case _:
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"user": user.session_id,
|
||||
}
|
||||
|
||||
return WorkflowFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
@ -562,15 +564,16 @@ class WorkflowResponseConverter:
|
||||
outputs, outputs_truncated = self._truncate_mapping(encoded_outputs)
|
||||
metadata = self._merge_metadata(event.execution_metadata, snapshot)
|
||||
|
||||
if isinstance(event, QueueNodeSucceededEvent):
|
||||
status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
error_message = event.error
|
||||
elif isinstance(event, QueueNodeFailedEvent):
|
||||
status = WorkflowNodeExecutionStatus.FAILED
|
||||
error_message = event.error
|
||||
else:
|
||||
status = WorkflowNodeExecutionStatus.EXCEPTION
|
||||
error_message = event.error
|
||||
match event:
|
||||
case QueueNodeSucceededEvent():
|
||||
status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
error_message = event.error
|
||||
case QueueNodeFailedEvent():
|
||||
status = WorkflowNodeExecutionStatus.FAILED
|
||||
error_message = event.error
|
||||
case _:
|
||||
status = WorkflowNodeExecutionStatus.EXCEPTION
|
||||
error_message = event.error
|
||||
|
||||
return NodeFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
|
||||
@ -109,17 +109,18 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
|
||||
"created_at": chunk.created_at,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
||||
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
if not isinstance(metadata, dict):
|
||||
metadata = {}
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
elif isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
match sub_stream_response:
|
||||
case MessageEndStreamResponse():
|
||||
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
if not isinstance(metadata, dict):
|
||||
metadata = {}
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
case ErrorStreamResponse():
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
case _:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
|
||||
yield response_chunk
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
@ -15,6 +15,7 @@ from core.app.entities.task_entities import (
|
||||
|
||||
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]):
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
|
||||
"""
|
||||
Convert blocking full response.
|
||||
@ -24,6 +25,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[Workflow
|
||||
return dict(blocking_response.model_dump())
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
@ -33,6 +35,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[Workflow
|
||||
return cls.convert_blocking_full_response(blocking_response)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
@ -66,6 +69,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[Workflow
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
|
||||
@ -200,6 +200,21 @@ class AgentChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGe
|
||||
pass
|
||||
|
||||
|
||||
class AgentAppGenerateEntity(ChatAppGenerateEntity):
|
||||
"""
|
||||
Agent App (new Agent app type) Generate Entity.
|
||||
|
||||
Subclasses ``ChatAppGenerateEntity`` so it rides the exact same EasyUI chat
|
||||
pipeline (generator, task pipeline, message cycle) without widening every
|
||||
accepted-entity union. The answer is produced by the dify-agent backend
|
||||
rather than an in-process LLM call; ``model_conf`` is synthesized from the
|
||||
bound Agent Soul model so the chat task pipeline can persist usage.
|
||||
"""
|
||||
|
||||
agent_id: str
|
||||
agent_config_snapshot_id: str
|
||||
|
||||
|
||||
class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
||||
"""
|
||||
Advanced Chat Application Generate Entity.
|
||||
|
||||
@ -46,13 +46,14 @@ class BasedGenerateTaskPipeline:
|
||||
e = event.error
|
||||
err: Exception
|
||||
|
||||
if isinstance(e, InvokeAuthorizationError):
|
||||
err = InvokeAuthorizationError("Incorrect API key provided")
|
||||
elif isinstance(e, InvokeError | ValueError):
|
||||
err = e
|
||||
else:
|
||||
description = getattr(e, "description", None)
|
||||
err = Exception(description if description is not None else str(e))
|
||||
match e:
|
||||
case InvokeAuthorizationError():
|
||||
err = InvokeAuthorizationError("Incorrect API key provided")
|
||||
case InvokeError() | ValueError():
|
||||
err = e
|
||||
case _:
|
||||
description = getattr(e, "description", None)
|
||||
err = Exception(description if description is not None else str(e))
|
||||
|
||||
if not message_id or not session:
|
||||
return err
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user