Compare commits

..

3 Commits

945 changed files with 14692 additions and 34281 deletions

View File

@ -0,0 +1 @@
../../.agents/skills/frontend-query-mutation

View File

@ -1 +0,0 @@
../../.agents/skills/how-to-write-component

View File

@ -51,15 +51,6 @@ jobs:
with:
files: |
api/**
- name: Check dify-agent inputs
if: github.event_name != 'merge_group'
id: dify-agent-changes
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
dify-agent/**/*.py
dify-agent/pyproject.toml
dify-agent/uv.lock
- if: github.event_name != 'merge_group'
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
with:
@ -85,17 +76,6 @@ jobs:
# Format code
uv run ruff format ..
- if: github.event_name != 'merge_group' && steps.dify-agent-changes.outputs.any_changed == 'true'
run: |
cd dify-agent
uv sync --dev
# fmt first to avoid line too long
uv run ruff format .
# Fix lint errors
uv run ruff check --fix .
# Format code
uv run ruff format .
- name: count migration progress
if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true'
run: |

View File

@ -1,4 +1,4 @@
name: Deploy SaaS
name: Deploy Agent Dev
permissions:
contents: read
@ -7,7 +7,7 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "deploy/saas"
- "deploy/agent-dev"
types:
- completed
@ -16,13 +16,13 @@ jobs:
runs-on: depot-ubuntu-24.04
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/saas'
github.event.workflow_run.head_branch == 'deploy/agent-dev'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
with:
host: ${{ secrets.SAAS_DEV_SSH_HOST }}
host: ${{ secrets.AGENT_DEV_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |
${{ vars.SSH_SCRIPT_SAAS_DEV || secrets.SSH_SCRIPT_SAAS_DEV }}
${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}

View File

@ -95,51 +95,6 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
uses: ./.github/actions/setup-web
- name: Web tsslint
if: steps.changed-files.outputs.any_changed == 'true'
env:
NODE_OPTIONS: --max-old-space-size=4096
run: vp run lint:tss
- name: Web dead code check
if: steps.changed-files.outputs.any_changed == 'true'
run: vp run knip
ts-common-style:
name: TS Common
runs-on: depot-ubuntu-24.04
permissions:
checks: write
pull-requests: read
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
web/**
cli/**
e2e/**
sdks/nodejs-client/**
packages/**
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.nvmrc
eslint.config.mjs
.github/workflows/style.yml
.github/actions/setup-web/**
- name: Setup web environment
if: steps.changed-files.outputs.any_changed == 'true'
uses: ./.github/actions/setup-web
- name: Restore ESLint cache
if: steps.changed-files.outputs.any_changed == 'true'
id: eslint-cache-restore
@ -150,14 +105,28 @@ jobs:
restore-keys: |
${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
- name: Style check
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: .
run: vp run lint:ci
- name: Type check
- name: Web tsslint
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
env:
NODE_OPTIONS: --max-old-space-size=4096
run: vp run lint:tss
- name: Web type check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: .
run: vp run type-check
- name: Web dead code check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: vp run knip
- name: Save ESLint cache
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5

3
.gitignore vendored
View File

@ -259,6 +259,3 @@ scripts/stress-test/reports/
.qoder/*
.context/
.eslintcache
# Vitest local reports
web/.vitest-reports/

View File

@ -1,27 +0,0 @@
# Security Policy
## Reporting a Vulnerability
If you believe you have found a security vulnerability in Dify, please report it privately through GitHub Security Advisories:
https://github.com/langgenius/dify/security/advisories/new
Please do not report security vulnerabilities through public GitHub issues, discussions, or pull requests.
When submitting a report, include as much relevant information as you can safely provide, such as:
- A description of the vulnerability
- Steps to reproduce, if safe to share privately
- Affected components, versions, or configurations
- Potential impact
- Any suggested mitigation or fix, if available
The maintainers will review reports submitted through GitHub Security Advisories and coordinate follow-up there.
## Public Disclosure
Please avoid publicly disclosing details of a vulnerability until it has been reviewed and, where appropriate, a fix or mitigation has been made available.
## Security Updates
Security fixes may be released through normal project releases or other appropriate channels. Users are encouraged to keep Dify deployments up to date.

View File

@ -34,7 +34,6 @@ from clients.agent_backend.request_builder import (
DIFY_PLUGIN_TOOLS_LAYER_ID,
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
WORKFLOW_USER_PROMPT_LAYER_ID,
AgentBackendAgentAppRunInput,
AgentBackendModelConfig,
AgentBackendOutputConfig,
AgentBackendRunRequestBuilder,
@ -50,7 +49,6 @@ __all__ = [
"DIFY_PLUGIN_TOOLS_LAYER_ID",
"WORKFLOW_NODE_JOB_PROMPT_LAYER_ID",
"WORKFLOW_USER_PROMPT_LAYER_ID",
"AgentBackendAgentAppRunInput",
"AgentBackendError",
"AgentBackendHTTPError",
"AgentBackendInternalEvent",

View File

@ -45,7 +45,6 @@ from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator
AGENT_SOUL_PROMPT_LAYER_ID = "agent_soul_prompt"
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID = "workflow_node_job_prompt"
WORKFLOW_USER_PROMPT_LAYER_ID = "workflow_user_prompt"
AGENT_APP_USER_PROMPT_LAYER_ID = "agent_app_user_prompt"
DIFY_EXECUTION_CONTEXT_LAYER_ID = "execution_context"
DIFY_PLUGIN_TOOLS_LAYER_ID = "tools"
@ -182,138 +181,9 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
return value
class AgentBackendAgentAppRunInput(BaseModel):
"""Inputs to build one Agent App conversation-turn run request.
Unlike the workflow-node input there is no workflow-node-job prompt and no
previous-node context: the user prompt is the chat message, and multi-turn
continuity comes from ``session_snapshot`` + the history layer keyed by the
conversation.
"""
model: AgentBackendModelConfig
execution_context: DifyExecutionContextLayerConfig
user_prompt: str
agent_soul_prompt: str | None = None
purpose: RunPurpose = "agent_app"
idempotency_key: str | None = None
output: AgentBackendOutputConfig | None = None
tools: DifyPluginToolsLayerConfig | None = None
session_snapshot: CompositorSessionSnapshot | None = None
include_history: bool = True
suspend_on_exit: bool = True
metadata: dict[str, JsonValue] = Field(default_factory=dict)
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
@field_validator("user_prompt")
@classmethod
def _reject_blank_prompt(cls, value: str) -> str:
if not value.strip():
raise ValueError("prompt must not be blank")
return value
class AgentBackendRunRequestBuilder:
"""Converts API product state into the public ``dify-agent`` run protocol."""
def build_for_agent_app(self, run_input: AgentBackendAgentAppRunInput) -> CreateRunRequest:
"""Build an Agent App conversation-turn run request.
Layer graph: optional Agent Soul system prompt → user prompt →
execution context → optional history (multi-turn) → LLM → optional
plugin tools → optional structured output. Mirrors the workflow-node
layer ordering minus the workflow-job / previous-node prompt.
"""
layers: list[RunLayerSpec] = []
if run_input.agent_soul_prompt:
layers.append(
RunLayerSpec(
name=AGENT_SOUL_PROMPT_LAYER_ID,
type=PLAIN_PROMPT_LAYER_TYPE_ID,
metadata={**run_input.metadata, "origin": "agent_soul"},
config=PromptLayerConfig(prefix=run_input.agent_soul_prompt),
)
)
layers.extend(
[
RunLayerSpec(
name=AGENT_APP_USER_PROMPT_LAYER_ID,
type=PLAIN_PROMPT_LAYER_TYPE_ID,
metadata={**run_input.metadata, "origin": "agent_app_user_prompt"},
config=PromptLayerConfig(user=run_input.user_prompt),
),
RunLayerSpec(
name=DIFY_EXECUTION_CONTEXT_LAYER_ID,
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
metadata=run_input.metadata,
config=run_input.execution_context,
),
]
)
if run_input.include_history:
layers.append(
RunLayerSpec(
name=DIFY_AGENT_HISTORY_LAYER_ID,
type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID,
metadata={**run_input.metadata, "origin": "agent_session_history"},
)
)
layers.append(
RunLayerSpec(
name=DIFY_AGENT_MODEL_LAYER_ID,
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
metadata=run_input.metadata,
config=DifyPluginLLMLayerConfig(
plugin_id=run_input.model.plugin_id,
model_provider=run_input.model.model_provider,
model=run_input.model.model,
credentials=run_input.model.credentials,
model_settings=run_input.model.model_settings or None,
),
)
)
if run_input.tools is not None and run_input.tools.tools:
layers.append(
RunLayerSpec(
name=DIFY_PLUGIN_TOOLS_LAYER_ID,
type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
metadata=run_input.metadata,
config=run_input.tools,
)
)
if run_input.output is not None:
layers.append(
RunLayerSpec(
name=DIFY_AGENT_OUTPUT_LAYER_ID,
type=DIFY_OUTPUT_LAYER_TYPE_ID,
metadata=run_input.metadata,
config=DifyOutputLayerConfig(
json_schema=run_input.output.json_schema,
description=run_input.output.description,
strict=run_input.output.strict,
),
)
)
return CreateRunRequest(
composition=RunComposition(layers=layers),
purpose=run_input.purpose,
idempotency_key=run_input.idempotency_key,
metadata=run_input.metadata,
session_snapshot=run_input.session_snapshot,
on_exit=LayerExitSignals(
default=ExitIntent.SUSPEND if run_input.suspend_on_exit else ExitIntent.DELETE,
),
)
def build_cleanup_request(
self,
*,

View File

@ -29,7 +29,6 @@ class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
raise NotImplementedError
@override
def __call__(self) -> dict[str, Any]:
current_state = self.current_state
remote_source_name = current_state.get("REMOTE_SETTINGS_SOURCE_NAME")

View File

@ -81,15 +81,4 @@ default_app_templates: Mapping[AppMode, Mapping] = {
},
},
},
# agent default mode (new Agent App type). The runtime model / prompt / tools
# come from the bound Agent Soul snapshot, so no model_config is seeded in the
# template; create_app still creates a model-less app_model_config row to hold
# app-level presentation features (opener, follow-up, citations, ...).
AppMode.AGENT: {
"app": {
"mode": AppMode.AGENT,
"enable_site": True,
"enable_api": True,
},
},
}

View File

@ -6,11 +6,10 @@ These helpers keep that translation centralized so models registered through
`register_schema_models` emit resolvable Swagger 2.0 references.
"""
from collections.abc import Iterable, Mapping
from collections.abc import Mapping
from enum import StrEnum
from typing import Any, Literal, NotRequired, Protocol, TypedDict
from typing import Any, Literal, NotRequired, TypedDict
from flask import request
from flask_restx import Namespace
from pydantic import BaseModel, TypeAdapter
@ -37,12 +36,6 @@ QueryParamDoc = TypedDict(
)
class QueryArgs(Protocol):
def to_dict(self, flat: bool = True) -> dict[str, str]: ...
def getlist(self, key: str) -> list[str]: ...
def _register_json_schema(namespace: Namespace, name: str, schema: dict) -> None:
"""Register a JSON schema and promote any nested Pydantic `$defs`."""
@ -174,58 +167,6 @@ def query_params_from_model(model: type[BaseModel]) -> dict[str, QueryParamDoc]:
return params
def query_params_from_request[ModelT: BaseModel](
model: type[ModelT],
*,
list_fields: Iterable[str] = (),
args: QueryArgs | None = None,
use_defaults_for_malformed_ints: bool = False,
) -> ModelT:
"""Validate query args with Pydantic while preserving Flask query parsing behavior.
Repeated params need explicit ``getlist()`` handling because Werkzeug's
``to_dict()`` keeps only one value. For malformed scalar integers, Flask's
For endpoints migrated from ``request.args.get(..., type=int, default=...)``,
set ``use_defaults_for_malformed_ints`` to preserve Flask's fallback to
defaults for malformed optional integer params.
"""
query_args = args or request.args
params: dict[str, Any] = query_args.to_dict()
for field_name in list_fields:
params[field_name] = query_args.getlist(field_name)
if use_defaults_for_malformed_ints:
_drop_malformed_defaulted_integer_params(model, params)
return model.model_validate(params)
def _drop_malformed_defaulted_integer_params(model: type[BaseModel], params: dict[str, Any]) -> None:
properties = model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0).get("properties", {})
if not isinstance(properties, Mapping):
return
for name, value in list(params.items()):
if not isinstance(value, str):
continue
field = model.model_fields.get(name)
if field is None or field.is_required():
continue
property_schema = properties.get(name)
if not isinstance(property_schema, Mapping):
continue
if _nullable_property_schema(property_schema).get("type") != "integer":
continue
try:
int(value)
except ValueError:
params.pop(name)
def _query_param_from_property(property_schema: Mapping[str, Any], *, required: bool) -> QueryParamDoc:
param_schema = _nullable_property_schema(property_schema)
param_doc: QueryParamDoc = {"in": "query", "required": required}
@ -298,7 +239,6 @@ __all__ = [
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
"get_or_create_model",
"query_params_from_model",
"query_params_from_request",
"register_enum_models",
"register_response_schema_model",
"register_response_schema_models",

View File

@ -51,8 +51,6 @@ from .agent import roster as agent_roster
from .app import (
advanced_prompt_template,
agent,
agent_app_access,
agent_app_feature,
annotation,
app,
audio,
@ -148,8 +146,6 @@ __all__ = [
"activate",
"advanced_prompt_template",
"agent",
"agent_app_access",
"agent_app_feature",
"agent_composer",
"agent_providers",
"agent_roster",

View File

@ -1,91 +1,53 @@
from flask_restx import Resource
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user_id,
)
from fields.agent_fields import (
AgentAppComposerResponse,
AgentComposerCandidatesResponse,
AgentComposerImpactResponse,
AgentComposerValidateResponse,
WorkflowAgentComposerResponse,
)
from libs.helper import dump_response
from libs.login import login_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from libs.login import current_account_with_tenant, login_required
from models.model import App, AppMode
from services.agent.composer_service import AgentComposerService
from services.agent.composer_validator import ComposerConfigValidator
from services.entities.agent_entities import ComposerSavePayload
register_schema_models(console_ns, ComposerSavePayload)
register_response_schema_models(
console_ns,
AgentAppComposerResponse,
AgentComposerCandidatesResponse,
AgentComposerImpactResponse,
AgentComposerValidateResponse,
WorkflowAgentComposerResponse,
)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer")
class WorkflowAgentComposerApi(Resource):
@console_ns.response(
200, "Workflow agent composer state", console_ns.models[WorkflowAgentComposerResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
@with_current_tenant_id
def get(self, tenant_id: str, app_model: App, node_id: str):
return dump_response(
WorkflowAgentComposerResponse,
AgentComposerService.load_workflow_composer(
tenant_id=tenant_id,
app_id=app_model.id,
node_id=node_id,
),
def get(self, app_model: App, node_id: str):
_, tenant_id = current_account_with_tenant()
return AgentComposerService.load_workflow_composer(
tenant_id=tenant_id,
app_id=app_model.id,
node_id=node_id,
)
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
@console_ns.response(
200, "Workflow agent composer saved", console_ns.models[WorkflowAgentComposerResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
@with_current_user_id
@with_current_tenant_id
def put(self, tenant_id: str, account_id: str, app_model: App, node_id: str):
def put(self, app_model: App, node_id: str):
account, tenant_id = current_account_with_tenant()
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
return dump_response(
WorkflowAgentComposerResponse,
AgentComposerService.save_workflow_composer(
tenant_id=tenant_id,
app_id=app_model.id,
node_id=node_id,
account_id=account_id,
payload=payload,
),
return AgentComposerService.save_workflow_composer(
tenant_id=tenant_id,
app_id=app_model.id,
node_id=node_id,
account_id=account.id,
payload=payload,
)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/validate")
class WorkflowAgentComposerValidateApi(Resource):
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
@console_ns.response(
200, "Workflow agent composer validation result", console_ns.models[AgentComposerValidateResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@ -93,116 +55,84 @@ class WorkflowAgentComposerValidateApi(Resource):
def post(self, app_model: App, node_id: str):
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
ComposerConfigValidator.validate_save_payload(payload)
return dump_response(AgentComposerValidateResponse, {"result": "success", "errors": []})
return {"result": "success", "errors": []}
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/candidates")
class WorkflowAgentComposerCandidatesApi(Resource):
@console_ns.response(
200, "Workflow agent composer candidates", console_ns.models[AgentComposerCandidatesResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
def get(self, app_model: App, node_id: str):
return dump_response(
AgentComposerCandidatesResponse,
AgentComposerService.get_workflow_candidates(app_id=app_model.id),
)
return AgentComposerService.get_workflow_candidates(app_id=app_model.id)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/impact")
class WorkflowAgentComposerImpactApi(Resource):
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
@console_ns.response(200, "Workflow agent composer impact", console_ns.models[AgentComposerImpactResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
@with_current_tenant_id
def post(self, tenant_id: str, app_model: App, node_id: str):
def post(self, app_model: App, node_id: str):
_, tenant_id = current_account_with_tenant()
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
current_snapshot_id = payload.binding.current_snapshot_id if payload.binding else None
if not current_snapshot_id:
return dump_response(
AgentComposerImpactResponse, {"current_snapshot_id": None, "workflow_node_count": 0, "bindings": []}
)
return dump_response(
AgentComposerImpactResponse,
AgentComposerService.calculate_impact(tenant_id=tenant_id, current_snapshot_id=current_snapshot_id),
)
return {"current_snapshot_id": None, "workflow_node_count": 0, "bindings": []}
return AgentComposerService.calculate_impact(tenant_id=tenant_id, current_snapshot_id=current_snapshot_id)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/save-to-roster")
class WorkflowAgentComposerSaveToRosterApi(Resource):
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
@console_ns.response(
200, "Workflow agent composer saved to roster", console_ns.models[WorkflowAgentComposerResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, account_id: str, app_model: App, node_id: str):
def post(self, app_model: App, node_id: str):
account, tenant_id = current_account_with_tenant()
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
return dump_response(
WorkflowAgentComposerResponse,
AgentComposerService.save_workflow_composer(
tenant_id=tenant_id,
app_id=app_model.id,
node_id=node_id,
account_id=account_id,
payload=payload,
),
return AgentComposerService.save_workflow_composer(
tenant_id=tenant_id,
app_id=app_model.id,
node_id=node_id,
account_id=account.id,
payload=payload,
)
@console_ns.route("/apps/<uuid:app_id>/agent-composer")
class AgentAppComposerApi(Resource):
@console_ns.response(200, "Agent app composer state", console_ns.models[AgentAppComposerResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model()
@with_current_tenant_id
def get(self, tenant_id: str, app_model: App):
return dump_response(
AgentAppComposerResponse,
AgentComposerService.load_agent_app_composer(tenant_id=tenant_id, app_id=app_model.id),
)
def get(self, app_model: App):
_, tenant_id = current_account_with_tenant()
return AgentComposerService.load_agent_app_composer(tenant_id=tenant_id, app_id=app_model.id)
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
@console_ns.response(200, "Agent app composer saved", console_ns.models[AgentAppComposerResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_app_model()
@with_current_user_id
@with_current_tenant_id
def put(self, tenant_id: str, account_id: str, app_model: App):
def put(self, app_model: App):
account, tenant_id = current_account_with_tenant()
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
return dump_response(
AgentAppComposerResponse,
AgentComposerService.save_agent_app_composer(
tenant_id=tenant_id,
app_id=app_model.id,
account_id=account_id,
payload=payload,
),
return AgentComposerService.save_agent_app_composer(
tenant_id=tenant_id,
app_id=app_model.id,
account_id=account.id,
payload=payload,
)
@console_ns.route("/apps/<uuid:app_id>/agent-composer/validate")
class AgentAppComposerValidateApi(Resource):
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
@console_ns.response(
200, "Agent app composer validation result", console_ns.models[AgentComposerValidateResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@ -210,20 +140,14 @@ class AgentAppComposerValidateApi(Resource):
def post(self, app_model: App):
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
ComposerConfigValidator.validate_save_payload(payload)
return dump_response(AgentComposerValidateResponse, {"result": "success", "errors": []})
return {"result": "success", "errors": []}
@console_ns.route("/apps/<uuid:app_id>/agent-composer/candidates")
class AgentAppComposerCandidatesApi(Resource):
@console_ns.response(
200, "Agent app composer candidates", console_ns.models[AgentComposerCandidatesResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@get_app_model()
def get(self, app_model: App):
return dump_response(
AgentComposerCandidatesResponse,
AgentComposerService.get_agent_app_candidates(app_id=app_model.id),
)
return AgentComposerService.get_agent_app_candidates(app_id=app_model.id)

View File

@ -4,25 +4,11 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user_id,
)
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db
from fields.agent_fields import (
AgentConfigSnapshotDetailResponse,
AgentConfigSnapshotListResponse,
AgentInviteOptionsResponse,
AgentRosterListResponse,
AgentRosterResponse,
)
from libs.helper import dump_response
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.agent.roster_service import AgentRosterService
from services.entities.agent_entities import RosterAgentCreatePayload, RosterAgentUpdatePayload, RosterListQuery
@ -43,14 +29,6 @@ register_schema_models(
RosterAgentUpdatePayload,
RosterListQuery,
)
register_response_schema_models(
console_ns,
AgentConfigSnapshotDetailResponse,
AgentConfigSnapshotListResponse,
AgentInviteOptionsResponse,
AgentRosterListResponse,
AgentRosterResponse,
)
def _agent_roster_service() -> AgentRosterService:
@ -59,130 +37,96 @@ def _agent_roster_service() -> AgentRosterService:
@console_ns.route("/agents")
class AgentRosterListApi(Resource):
@console_ns.doc(params=query_params_from_model(RosterListQuery))
@console_ns.response(200, "Agent roster list", console_ns.models[AgentRosterListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str):
def get(self):
_, tenant_id = current_account_with_tenant()
query = RosterListQuery.model_validate(request.args.to_dict(flat=True))
return dump_response(
AgentRosterListResponse,
_agent_roster_service().list_roster_agents(
tenant_id=tenant_id, page=query.page, limit=query.limit, keyword=query.keyword
),
return _agent_roster_service().list_roster_agents(
tenant_id=tenant_id, page=query.page, limit=query.limit, keyword=query.keyword
)
@console_ns.expect(console_ns.models[RosterAgentCreatePayload.__name__])
@console_ns.response(201, "Agent created", console_ns.models[AgentRosterResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, account_id: str):
def post(self):
account, tenant_id = current_account_with_tenant()
payload = RosterAgentCreatePayload.model_validate(console_ns.payload or {})
service = _agent_roster_service()
agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account_id, payload=payload)
return dump_response(
AgentRosterResponse,
service.get_roster_agent_detail(tenant_id=tenant_id, agent_id=agent.id),
), 201
agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account.id, payload=payload)
return service.get_roster_agent_detail(tenant_id=tenant_id, agent_id=agent.id), 201
@console_ns.route("/agents/invite-options")
class AgentInviteOptionsApi(Resource):
@console_ns.doc(params=query_params_from_model(AgentInviteOptionsQuery))
@console_ns.response(200, "Agent invite options", console_ns.models[AgentInviteOptionsResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str):
def get(self):
_, tenant_id = current_account_with_tenant()
query = AgentInviteOptionsQuery.model_validate(request.args.to_dict(flat=True))
return dump_response(
AgentInviteOptionsResponse,
_agent_roster_service().list_invite_options(
tenant_id=tenant_id,
page=query.page,
limit=query.limit,
keyword=query.keyword,
app_id=query.app_id,
),
return _agent_roster_service().list_invite_options(
tenant_id=tenant_id,
page=query.page,
limit=query.limit,
keyword=query.keyword,
app_id=query.app_id,
)
@console_ns.route("/agents/<uuid:agent_id>")
class AgentRosterDetailApi(Resource):
@console_ns.response(200, "Agent detail", console_ns.models[AgentRosterResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str, agent_id: UUID):
return dump_response(
AgentRosterResponse,
_agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id)),
)
def get(self, agent_id: UUID):
_, tenant_id = current_account_with_tenant()
return _agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id))
@console_ns.expect(console_ns.models[RosterAgentUpdatePayload.__name__])
@console_ns.response(200, "Agent updated", console_ns.models[AgentRosterResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@with_current_user_id
@with_current_tenant_id
def patch(self, tenant_id: str, account_id: str, agent_id: UUID):
def patch(self, agent_id: UUID):
account, tenant_id = current_account_with_tenant()
payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {})
return dump_response(
AgentRosterResponse,
_agent_roster_service().update_roster_agent(
tenant_id=tenant_id, agent_id=str(agent_id), account_id=account_id, payload=payload
),
return _agent_roster_service().update_roster_agent(
tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id, payload=payload
)
@console_ns.response(204, "Agent archived")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@with_current_user_id
@with_current_tenant_id
def delete(self, tenant_id: str, account_id: str, agent_id: UUID):
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account_id)
def delete(self, agent_id: UUID):
account, tenant_id = current_account_with_tenant()
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id)
return "", 204
@console_ns.route("/agents/<uuid:agent_id>/versions")
class AgentRosterVersionsApi(Resource):
@console_ns.response(200, "Agent versions", console_ns.models[AgentConfigSnapshotListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str, agent_id: UUID):
return dump_response(
AgentConfigSnapshotListResponse,
{"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))},
)
def get(self, agent_id: UUID):
_, tenant_id = current_account_with_tenant()
return {"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))}
@console_ns.route("/agents/<uuid:agent_id>/versions/<uuid:version_id>")
class AgentRosterVersionDetailApi(Resource):
@console_ns.response(200, "Agent version detail", console_ns.models[AgentConfigSnapshotDetailResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str, agent_id: UUID, version_id: UUID):
return dump_response(
AgentConfigSnapshotDetailResponse,
_agent_roster_service().get_agent_version_detail(
tenant_id=tenant_id,
agent_id=str(agent_id),
version_id=str(version_id),
),
def get(self, agent_id: UUID, version_id: UUID):
_, tenant_id = current_account_with_tenant()
return _agent_roster_service().get_agent_version_detail(
tenant_id=tenant_id,
agent_id=str(agent_id),
version_id=str(version_id),
)

View File

@ -1,59 +0,0 @@
"""Agent App access & sharing endpoints (read-only workflow references).
An Agent App is backed by a roster Agent that workflow Agent nodes may also
reference. This exposes the read-only "Workflow access" surface from the PRD:
which workflow apps use this Agent, without leaking the workflows' internals.
"""
from flask_restx import Resource
from pydantic import Field
from controllers.common.schema import register_response_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.login import login_required
from models.model import App, AppMode
from services.agent.roster_service import AgentRosterService
class AgentReferencingWorkflowResponse(ResponseModel):
app_id: str
app_name: str
app_mode: str
workflow_id: str
node_ids: list[str] = Field(default_factory=list)
class AgentReferencingWorkflowsResponse(ResponseModel):
data: list[AgentReferencingWorkflowResponse] = Field(default_factory=list)
register_response_schema_models(console_ns, AgentReferencingWorkflowsResponse)
@console_ns.route("/apps/<uuid:app_id>/agent-referencing-workflows")
class AgentAppReferencingWorkflowsResource(Resource):
@console_ns.doc("list_agent_app_referencing_workflows")
@console_ns.doc(description="List workflow apps that reference this Agent App's bound Agent (read-only)")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(
200,
"Referencing workflows listed successfully",
console_ns.models[AgentReferencingWorkflowsResponse.__name__],
)
@console_ns.response(404, "App not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT])
@with_current_tenant_id
def get(self, tenant_id: str, app_model: App):
workflows = AgentRosterService(db.session).list_workflows_referencing_app_agent(
tenant_id=tenant_id, app_id=app_model.id
)
return AgentReferencingWorkflowsResponse(
data=[AgentReferencingWorkflowResponse.model_validate(workflow) for workflow in workflows]
).model_dump(mode="json")

View File

@ -1,93 +0,0 @@
"""Agent App presentation-feature configuration endpoint.
The new Agent App type keeps model / prompt / tools in its bound Agent Soul, so
the legacy ``/model-config`` surface (which writes model, prompt and agent tool
config) is the wrong place to configure its app-level presentation features.
This endpoint exposes only the PRD "Misc Legacy" feature subset — conversation
opener, follow-up suggestions, citations, content moderation and speech — and
persists them onto the app's ``app_model_config`` without touching anything the
Soul owns.
"""
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user,
)
from events.app_event import app_model_config_was_updated
from libs.helper import dump_response
from libs.login import login_required
from models import Account
from models.agent_config_entities import (
AgentFeatureToggleConfig,
AgentSensitiveWordAvoidanceFeatureConfig,
AgentSuggestedQuestionsAfterAnswerFeatureConfig,
AgentTextToSpeechFeatureConfig,
)
from models.model import App, AppMode
from services.agent_app_feature_service import AgentAppFeatureConfigService
class AgentAppFeaturesPayload(BaseModel):
"""Presentation features configurable on an Agent App.
All fields are optional; an omitted field is reset to its disabled/empty
default (the config form sends the full desired feature state on save).
"""
opening_statement: str | None = Field(default=None, description="Conversation opener shown before the first turn")
suggested_questions: list[str] | None = Field(
default=None, description="Preset questions shown alongside the opener"
)
suggested_questions_after_answer: AgentSuggestedQuestionsAfterAnswerFeatureConfig | None = Field(
default=None, description="Follow-up suggestions config, e.g. {'enabled': true}"
)
speech_to_text: AgentFeatureToggleConfig | None = Field(default=None, description="Speech-to-text config")
text_to_speech: AgentTextToSpeechFeatureConfig | None = Field(default=None, description="Text-to-speech config")
retriever_resource: AgentFeatureToggleConfig | None = Field(
default=None, description="Citations / attributions config, e.g. {'enabled': true}"
)
sensitive_word_avoidance: AgentSensitiveWordAvoidanceFeatureConfig | None = Field(
default=None, description="Content moderation config"
)
register_schema_models(console_ns, AgentAppFeaturesPayload)
register_response_schema_models(console_ns, SimpleResultResponse)
@console_ns.route("/apps/<uuid:app_id>/agent-features")
class AgentAppFeatureConfigResource(Resource):
@console_ns.doc("update_agent_app_features")
@console_ns.doc(description="Update an Agent App's presentation features (opener, follow-up, citations, ...)")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AgentAppFeaturesPayload.__name__])
@console_ns.response(200, "Features updated successfully", console_ns.models[SimpleResultResponse.__name__])
@console_ns.response(400, "Invalid configuration")
@console_ns.response(404, "App not found")
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT])
@with_current_user
def post(self, current_user: Account, app_model: App):
args = AgentAppFeaturesPayload.model_validate(console_ns.payload or {})
new_app_model_config = AgentAppFeatureConfigService.update_features(
app_model=app_model,
account=current_user,
config=args.model_dump(exclude_none=True),
)
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
return dump_response(SimpleResultResponse, {"result": "success"})

View File

@ -25,9 +25,6 @@ from controllers.console.wraps import (
enterprise_license_required,
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
with_current_user,
with_current_user_id,
)
from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.entities import PreProcessingRule, Rule, Segmentation
@ -37,8 +34,8 @@ from extensions.ext_database import db
from fields.base import ResponseModel
from graphon.enums import WorkflowExecutionStatus
from libs.helper import build_icon_url, to_timestamp
from libs.login import login_required
from models import Account, App, DatasetPermissionEnum, Workflow
from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow
from models.model import IconType
from services.app_dsl_service import AppDslService
from services.app_service import AppListParams, AppService, CreateAppParams
@ -58,7 +55,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
)
from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "agent", "advanced-chat", "workflow", "completion"]
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
register_enum_models(console_ns, IconType)
@ -69,7 +66,7 @@ _TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
class AppListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"] = Field(
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
default="all", description="App mode filter"
)
name: str | None = Field(default=None, description="Filter by app name")
@ -118,9 +115,7 @@ def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str,
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
mode: Literal["chat", "agent-chat", "agent", "advanced-chat", "workflow", "completion"] = Field(
..., description="App mode"
)
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
icon_type: IconType | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@ -398,8 +393,6 @@ class AppDetailWithSite(AppDetail):
max_active_requests: int | None = None
deleted_tools: list[DeletedTool] = Field(default_factory=list)
site: Site | None = None
# For Agent App type: the roster Agent backing this app (None otherwise).
bound_agent_id: str | None = None
@computed_field(return_type=str | None) # type: ignore
@property
@ -474,11 +467,10 @@ class AppListApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
@with_session(write=False)
@with_current_user_id
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user_id: str, session: Session):
def get(self):
"""Get app list"""
current_user, current_tenant_id = current_account_with_tenant()
args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args))
params = AppListParams(
page=args.page,
@ -491,7 +483,7 @@ class AppListApi(Resource):
# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user_id, current_tenant_id, params)
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, params)
if not app_pagination:
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
return empty.model_dump(mode="json"), 200
@ -512,7 +504,7 @@ class AppListApi(Resource):
draft_trigger_app_ids: set[str] = set()
if workflow_capable_app_ids:
draft_workflows = (
session.execute(
db.session.execute(
select(Workflow).where(
Workflow.version == Workflow.VERSION_DRAFT,
Workflow.app_id.in_(workflow_capable_app_ids),
@ -551,10 +543,9 @@ class AppListApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account):
def post(self):
"""Create app"""
current_user, current_tenant_id = current_account_with_tenant()
args = CreateAppPayload.model_validate(console_ns.payload)
params = CreateAppParams(
name=args.name,
@ -657,10 +648,11 @@ class AppCopyApi(Resource):
@account_initialization_required
@get_app_model(mode=None)
@edit_permission_required
@with_current_user
def post(self, current_user: Account, app_model: App):
def post(self, app_model: App):
"""Copy app"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
args = CopyAppPayload.model_validate(console_ns.payload or {})
with Session(db.engine, expire_on_commit=False) as session:
@ -739,8 +731,7 @@ class AppPublishToCreatorsPlatformApi(Resource):
@account_initialization_required
@get_app_model(mode=None)
@edit_permission_required
@with_current_user_id
def post(self, current_user_id: str, app_model: App):
def post(self, app_model: App):
"""Publish app to Creators Platform"""
from configs import dify_config
from core.helper.creators import get_redirect_url, upload_dsl
@ -748,11 +739,13 @@ class AppPublishToCreatorsPlatformApi(Resource):
if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
return {"error": "Creators Platform features are not enabled"}, 403
current_user, _ = current_account_with_tenant()
dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
dsl_bytes = dsl_content.encode("utf-8")
claim_code = upload_dsl(dsl_bytes)
redirect_url = get_redirect_url(current_user_id, claim_code)
redirect_url = get_redirect_url(str(current_user.id), claim_code)
return {"redirect_url": redirect_url}

View File

@ -9,11 +9,9 @@ from controllers.console.wraps import (
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
with_current_user,
)
from extensions.ext_database import db
from libs.login import login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models.model import App
from services.app_dsl_service import AppDslService, Import
from services.enterprise.enterprise_service import EnterpriseService
@ -50,9 +48,9 @@ class AppImportApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
@with_current_user
def post(self, current_user: Account):
def post(self):
# Check user role first
current_user, _ = current_account_with_tenant()
args = AppImportPayload.model_validate(console_ns.payload)
# AppDslService performs internal commits for some creation paths, so use a plain
@ -99,9 +97,10 @@ class AppImportConfirmApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@with_current_user
def post(self, current_user: Account, import_id: str):
def post(self, import_id: str):
# Check user role first
current_user, _ = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
import_service = AppDslService(session)
# Confirm import

View File

@ -4,7 +4,7 @@ from typing import Any, Literal
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.common.fields import SimpleResultResponse
@ -19,12 +19,7 @@ from controllers.console.app.error import (
ProviderQuotaExceededError,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user_id,
)
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
@ -46,24 +41,9 @@ from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
def _resolve_debugger_chat_streaming(
*, app_mode: AppMode, response_mode: str, response_mode_provided: bool = True
) -> bool:
"""Agent App runtime is SSE-only until backend blocking runs are supported."""
if app_mode != AppMode.AGENT:
return response_mode != "blocking"
if response_mode_provided and response_mode == "blocking":
raise BadRequest("Agent App only supports streaming response mode.")
return True
class BaseMessagePayload(BaseModel):
inputs: dict[str, Any]
# Agent Apps (AppMode.AGENT) derive their model + prompt from the bound Agent
# Soul, so no override ``model_config`` is sent; chat / agent-chat / completion
# debugging still pass it. Optional here, required in practice by those modes
# downstream when their config is built from args.
model_config_data: dict[str, Any] = Field(default_factory=dict, alias="model_config")
model_config_data: dict[str, Any] = Field(..., alias="model_config")
files: list[Any] | None = Field(default=None, description="Uploaded files")
response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
retriever_from: str = Field(default="dev", description="Retriever source")
@ -151,13 +131,14 @@ class CompletionMessageStopApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@with_current_user_id
def post(self, current_user_id: str, app_model: App, task_id: str):
def post(self, app_model: App, task_id: str):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.DEBUGGER,
user_id=current_user_id,
user_id=current_user.id,
app_mode=AppMode.value_of(app_model.mode),
)
@ -176,20 +157,13 @@ class ChatMessageApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.AGENT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@edit_permission_required
def post(self, app_model: App):
raw_payload = console_ns.payload or {}
args_model = ChatMessagePayload.model_validate(raw_payload)
args_model = ChatMessagePayload.model_validate(console_ns.payload)
args = args_model.model_dump(exclude_none=True, by_alias=True)
streaming = _resolve_debugger_chat_streaming(
app_mode=AppMode.value_of(app_model.mode),
response_mode=args_model.response_mode,
response_mode_provided=isinstance(raw_payload, dict) and "response_mode" in raw_payload,
)
if AppMode.value_of(app_model.mode) == AppMode.AGENT:
args["response_mode"] = "streaming"
streaming = args_model.response_mode != "blocking"
args["auto_generate_name"] = False
external_trace_id = get_external_trace_id(request)
@ -237,14 +211,15 @@ class ChatMessageStopApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@with_current_user_id
def post(self, current_user_id: str, app_model: App, task_id: str):
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def post(self, app_model: App, task_id: str):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.DEBUGGER,
user_id=current_user_id,
user_id=current_user.id,
app_mode=AppMode.value_of(app_model.mode),
)

View File

@ -12,12 +12,7 @@ from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user,
)
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
@ -36,9 +31,8 @@ from fields.conversation_fields import (
ConversationWithSummaryPagination as ConversationWithSummaryPaginationResponse,
)
from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models import Conversation, EndUser, Message, MessageAnnotation
from models.account import Account
from models.model import App, AppMode
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError
@ -99,8 +93,8 @@ class CompletionConversationApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
@with_current_user
def get(self, current_user: Account, app_model: App):
def get(self, app_model: App):
current_user, _ = current_account_with_tenant()
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True))
query = sa.select(Conversation).where(
@ -171,11 +165,10 @@ class CompletionConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
@with_current_user
def get(self, current_user: Account, app_model: App, conversation_id: UUID):
def get(self, app_model: App, conversation_id: UUID):
conversation_id_str = str(conversation_id)
return ConversationMessageDetailResponse.model_validate(
_get_conversation(current_user, app_model, conversation_id_str), from_attributes=True
_get_conversation(app_model, conversation_id_str), from_attributes=True
).model_dump(mode="json")
@console_ns.doc("delete_completion_conversation")
@ -189,8 +182,8 @@ class CompletionConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
@with_current_user
def delete(self, current_user: Account, app_model: App, conversation_id: UUID):
def delete(self, app_model: App, conversation_id: UUID):
current_user, _ = current_account_with_tenant()
conversation_id_str = str(conversation_id)
try:
@ -212,10 +205,10 @@ class ChatConversationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@edit_permission_required
@with_current_user
def get(self, current_user: Account, app_model: App):
def get(self, app_model: App):
current_user, _ = current_account_with_tenant()
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True))
subquery = (
@ -323,13 +316,12 @@ class ChatConversationDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@edit_permission_required
@with_current_user
def get(self, current_user: Account, app_model: App, conversation_id: UUID):
def get(self, app_model: App, conversation_id: UUID):
conversation_id_str = str(conversation_id)
return ConversationDetailResponse.model_validate(
_get_conversation(current_user, app_model, conversation_id_str), from_attributes=True
_get_conversation(app_model, conversation_id_str), from_attributes=True
).model_dump(mode="json")
@console_ns.doc("delete_chat_conversation")
@ -340,11 +332,11 @@ class ChatConversationDetailApi(Resource):
@console_ns.response(404, "Conversation not found")
@setup_required
@login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required
@edit_permission_required
@with_current_user
def delete(self, current_user: Account, app_model: App, conversation_id: UUID):
def delete(self, app_model: App, conversation_id: UUID):
current_user, _ = current_account_with_tenant()
conversation_id_str = str(conversation_id)
try:
@ -355,7 +347,8 @@ class ChatConversationDetailApi(Resource):
return "", 204
def _get_conversation(current_user: Account, app_model, conversation_id):
def _get_conversation(app_model, conversation_id):
current_user, _ = current_account_with_tenant()
conversation = db.session.scalar(
sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1)
)

View File

@ -2,7 +2,6 @@ from collections.abc import Sequence
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
@ -12,7 +11,6 @@ from controllers.console.app.error import (
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.app.wraps import with_session
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
from core.app.app_config.entities import ModelConfig
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
@ -21,6 +19,7 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.llm_generator import LLMGenerator
from extensions.ext_database import db
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
@ -159,8 +158,7 @@ class InstructionGenerateApi(Resource):
@login_required
@account_initialization_required
@with_current_tenant_id
@with_session(write=False)
def post(self, session: Session, current_tenant_id: str):
def post(self, current_tenant_id: str):
args = InstructionGeneratePayload.model_validate(console_ns.payload)
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] | None = next(
@ -170,10 +168,10 @@ class InstructionGenerateApi(Resource):
try:
# Generate from nothing for a workflow node
if (args.current in (code_template, "")) and args.node_id != "":
app = session.get(App, args.flow_id)
app = db.session.get(App, args.flow_id)
if not app:
return {"error": f"app {args.flow_id} not found"}, 400
workflow = WorkflowService().get_draft_workflow(app_model=app, session=session)
workflow = WorkflowService().get_draft_workflow(app_model=app)
if not workflow:
return {"error": f"workflow {args.flow_id} not found"}, 400
nodes: Sequence = workflow.graph_dict["nodes"]

View File

@ -25,7 +25,6 @@ from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user,
)
from core.app.entities.app_invoke_entities import InvokeFrom
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
@ -44,8 +43,7 @@ from fields.conversation_fields import (
from graphon.model_runtime.errors.invoke import InvokeError
from libs.helper import to_timestamp, uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import App, AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.errors.conversation import ConversationNotExistsError
@ -180,7 +178,7 @@ class ChatMessageListApi(Resource):
@login_required
@account_initialization_required
@setup_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@edit_permission_required
def get(self, app_model: App):
args = ChatMessagesQuery.model_validate(request.args.to_dict())
@ -259,8 +257,9 @@ class MessageFeedbackApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def post(self, current_user: Account, app_model: App):
def post(self, app_model: App):
current_user, _ = current_account_with_tenant()
args = MessageFeedbackPayload.model_validate(console_ns.payload)
message_id = str(args.message_id)
@ -337,9 +336,9 @@ class MessageSuggestedQuestionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@with_current_user
def get(self, current_user: Account, app_model: App, message_id: UUID):
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model: App, message_id: UUID):
current_user, _ = current_account_with_tenant()
message_id_str = str(message_id)
try:

View File

@ -8,20 +8,14 @@ from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user_id,
)
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.agent.entities import AgentToolEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.model import App, AppMode, AppModelConfig
from services.app_model_config_service import AppModelConfigService
@ -58,10 +52,9 @@ class ModelConfigResource(Resource):
@edit_permission_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
@with_current_user_id
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user_id: str, app_model: App):
def post(self, app_model: App):
"""Modify app model config"""
current_user, current_tenant_id = current_account_with_tenant()
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_tenant_id,
@ -71,8 +64,8 @@ class ModelConfigResource(Resource):
new_app_model_config = AppModelConfig(
app_id=app_model.id,
created_by=current_user_id,
updated_by=current_user_id,
created_by=current_user.id,
updated_by=current_user.id,
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
@ -97,7 +90,7 @@ class ModelConfigResource(Resource):
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
user_id=current_user_id,
user_id=current_user.id,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_tenant_id,
@ -137,7 +130,7 @@ class ModelConfigResource(Resource):
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
user_id=current_user_id,
user_id=current_user.id,
)
except Exception:
continue
@ -174,7 +167,7 @@ class ModelConfigResource(Resource):
db.session.flush()
app_model.app_model_config_id = new_app_model_config.id
app_model.updated_by = current_user_id
app_model.updated_by = current_user.id
app_model.updated_at = naive_utc_now()
db.session.commit()

View File

@ -14,14 +14,12 @@ from controllers.console.wraps import (
edit_permission_required,
is_admin_or_owner_required,
setup_required,
with_current_user,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models import Site
from models.account import Account
from models.model import App
@ -87,9 +85,9 @@ class AppSite(Resource):
@edit_permission_required
@account_initialization_required
@get_app_model
@with_current_user
def post(self, current_user: Account, app_model: App):
def post(self, app_model: App):
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
current_user, _ = current_account_with_tenant()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
if not site:
raise NotFound
@ -136,8 +134,8 @@ class AppSiteAccessTokenReset(Resource):
@is_admin_or_owner_required
@account_initialization_required
@get_app_model
@with_current_user
def post(self, current_user: Account, app_model: App):
def post(self, app_model: App):
current_user, _ = current_account_with_tenant()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
if not site:

View File

@ -8,14 +8,13 @@ from pydantic import BaseModel, Field, field_validator
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import convert_datetime_to_date
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models import AppMode
from models.account import Account
from models.model import App
@ -49,8 +48,9 @@ class DailyMessageStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def get(self, account: Account, app_model: App):
def get(self, app_model: App):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")
@ -109,8 +109,9 @@ class DailyConversationStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def get(self, account: Account, app_model: App):
def get(self, app_model: App):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")
@ -168,8 +169,9 @@ class DailyTerminalsStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def get(self, account: Account, app_model: App):
def get(self, app_model: App):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")
@ -228,8 +230,9 @@ class DailyTokenCostStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def get(self, account: Account, app_model: App):
def get(self, app_model: App):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")
@ -290,9 +293,10 @@ class AverageSessionInteractionStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@with_current_user
def get(self, account: Account, app_model: App):
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model: App):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("c.created_at")
@ -370,8 +374,9 @@ class UserSatisfactionRateStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def get(self, account: Account, app_model: App):
def get(self, app_model: App):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("m.created_at")
@ -439,8 +444,9 @@ class AverageResponseTimeStatistic(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@with_current_user
def get(self, account: Account, app_model: App):
def get(self, app_model: App):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")
@ -499,8 +505,8 @@ class TokensPerSecondStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def get(self, account: Account, app_model: App):
def get(self, app_model: App):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")

View File

@ -6,11 +6,10 @@ from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.login import login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models.enums import WorkflowRunTriggeredFrom
from models.model import App, AppMode
from repositories.factory import DifyAPIRepositoryFactory
@ -47,8 +46,9 @@ class WorkflowDailyRunsStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def get(self, account: Account, app_model: App):
def get(self, app_model: App):
account, _ = current_account_with_tenant()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
assert account.timezone is not None
@ -86,8 +86,9 @@ class WorkflowDailyTerminalsStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def get(self, account: Account, app_model: App):
def get(self, app_model: App):
account, _ = current_account_with_tenant()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
assert account.timezone is not None
@ -125,8 +126,9 @@ class WorkflowDailyTokenCostStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def get(self, account: Account, app_model: App):
def get(self, app_model: App):
account, _ = current_account_with_tenant()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
assert account.timezone is not None
@ -164,8 +166,9 @@ class WorkflowAverageAppInteractionStatistic(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@with_current_user
def get(self, account: Account, app_model: App):
def get(self, app_model: App):
account, _ = current_account_with_tenant()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
assert account.timezone is not None

View File

@ -32,11 +32,11 @@ from controllers.console.wraps import (
decrypt_password_field,
email_password_login_enabled,
setup_required,
with_current_user,
)
from events.tenant_event import tenant_was_created
from libs.helper import EmailStr, extract_remote_ip
from libs.helper import timezone as validate_timezone_string
from libs.login import current_account_with_tenant
from libs.token import (
clear_access_token_from_cookie,
clear_csrf_token_from_cookie,
@ -46,7 +46,6 @@ from libs.token import (
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
)
from models.account import Account
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
from services.billing_service import BillingService
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
@ -173,8 +172,9 @@ class LoginApi(Resource):
class LogoutApi(Resource):
@setup_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
@with_current_user
def post(self, account: Account):
def post(self):
current_user, _ = current_account_with_tenant()
account = current_user
if isinstance(account, flask_login.AnonymousUserMixin):
response = make_response({"result": "success"})
else:

View File

@ -8,16 +8,9 @@ from werkzeug.exceptions import BadRequest
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
only_edition_cloud,
setup_required,
with_current_tenant_id,
with_current_user,
)
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from enums.cloud_plan import CloudPlan
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
@ -39,9 +32,8 @@ class Subscription(Resource):
@login_required
@account_initialization_required
@only_edition_cloud
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True))
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
@ -53,9 +45,8 @@ class Invoices(Resource):
@login_required
@account_initialization_required
@only_edition_cloud
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_invoices(current_user.email, current_tenant_id)
@ -72,8 +63,9 @@ class PartnerTenants(Resource):
@login_required
@account_initialization_required
@only_edition_cloud
@with_current_user
def put(self, current_user: Account, partner_key: str):
def put(self, partner_key: str):
current_user, _ = current_account_with_tenant()
try:
args = PartnerTenantsPayload.model_validate(console_ns.payload or {})
click_id = args.click_id

View File

@ -3,18 +3,11 @@ from flask_restx import Resource
from pydantic import BaseModel, Field
from libs.helper import extract_remote_ip
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
from .. import console_ns
from ..wraps import (
account_initialization_required,
only_edition_cloud,
setup_required,
with_current_tenant_id,
with_current_user,
)
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
class ComplianceDownloadQuery(BaseModel):
@ -36,9 +29,8 @@ class ComplianceApi(Resource):
@login_required
@account_initialization_required
@only_edition_cloud
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True))
ip_address = extract_remote_ip(request)

View File

@ -9,7 +9,7 @@ from uuid import UUID
import sqlalchemy as sa
from flask import request, send_file
from flask_restx import Resource
from flask_restx import Resource, marshal
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import asc, desc, func, select
from werkzeug.exceptions import Forbidden, NotFound
@ -34,16 +34,14 @@ from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.document_fields import (
DocumentMetadataResponse,
DocumentResponse,
DocumentStatusListResponse,
DocumentStatusResponse,
normalize_enum,
document_fields,
document_status_fields,
document_with_segments_fields,
)
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
from libs.datetime_utils import naive_utc_now
from libs.helper import dump_response, to_timestamp
from libs.helper import to_timestamp
from libs.login import current_account_with_tenant, login_required
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
@ -76,6 +74,12 @@ from ..wraps import (
logger = logging.getLogger(__name__)
def _normalize_enum(value: Any) -> Any:
if isinstance(value, str) or value is None:
return value
return getattr(value, "value", value)
class DatasetResponse(ResponseModel):
id: str
name: str
@ -89,7 +93,7 @@ class DatasetResponse(ResponseModel):
@field_validator("data_source_type", "indexing_technique", mode="before")
@classmethod
def _normalize_enum_fields(cls, value: Any) -> Any:
return normalize_enum(value)
return _normalize_enum(value)
@field_validator("created_at", mode="before")
@classmethod
@ -97,10 +101,61 @@ class DatasetResponse(ResponseModel):
return to_timestamp(value)
class DocumentMetadataResponse(ResponseModel):
id: str
name: str
type: str
value: str | None = None
class DocumentResponse(ResponseModel):
id: str
position: int | None = None
data_source_type: str | None = None
data_source_info: Any = Field(default=None, validation_alias="data_source_info_dict")
data_source_detail_dict: Any = None
dataset_process_rule_id: str | None = None
name: str
created_from: str | None = None
created_by: str | None = None
created_at: int | None = None
tokens: int | None = None
indexing_status: str | None = None
error: str | None = None
enabled: bool | None = None
disabled_at: int | None = None
disabled_by: str | None = None
archived: bool | None = None
display_status: str | None = None
word_count: int | None = None
hit_count: int | None = None
doc_form: str | None = None
doc_metadata: list[DocumentMetadataResponse] = Field(default_factory=list, validation_alias="doc_metadata_details")
summary_index_status: str | None = None
need_summary: bool | None = None
@field_validator("data_source_type", "indexing_status", "display_status", "doc_form", mode="before")
@classmethod
def _normalize_enum_fields(cls, value: Any) -> Any:
return _normalize_enum(value)
@field_validator("doc_metadata", mode="before")
@classmethod
def _normalize_doc_metadata(cls, value: Any) -> list[Any]:
if value is None:
return []
return value
@field_validator("created_at", "disabled_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
class DocumentWithSegmentsResponse(DocumentResponse):
process_rule_dict: Any = None
completed_segments: int | None = Field(default=None, exclude_if=lambda value: value is None)
total_segments: int | None = Field(default=None, exclude_if=lambda value: value is None)
completed_segments: int | None = None
total_segments: int | None = None
class DatasetAndDocumentResponse(ResponseModel):
@ -135,14 +190,6 @@ class DocumentDatasetListParam(BaseModel):
fetch_val: str = Field("false", alias="fetch")
class DocumentWithSegmentsListResponse(ResponseModel):
data: list[DocumentWithSegmentsResponse]
has_more: bool
limit: int
total: int
page: int
register_schema_models(
console_ns,
KnowledgeConfig,
@ -153,19 +200,13 @@ register_schema_models(
GenerateSummaryPayload,
DocumentMetadataUpdatePayload,
DocumentBatchDownloadZipPayload,
)
register_response_schema_models(
console_ns,
SimpleResultMessageResponse,
SimpleResultResponse,
UrlResponse,
DatasetResponse,
DocumentMetadataResponse,
DocumentResponse,
DocumentWithSegmentsResponse,
DatasetAndDocumentResponse,
DocumentWithSegmentsListResponse,
)
register_response_schema_models(console_ns, SimpleResultMessageResponse, SimpleResultResponse, UrlResponse)
class DocumentResource(Resource):
@ -271,11 +312,7 @@ class DatasetDocumentListApi(Resource):
"status": "Filter documents by display status",
}
)
@console_ns.response(
200,
"Documents retrieved successfully",
console_ns.models[DocumentWithSegmentsListResponse.__name__],
)
@console_ns.response(200, "Documents retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@ -388,15 +425,18 @@ class DatasetDocumentListApi(Resource):
)
document.completed_segments = completed_segments
document.total_segments = total_segments
data = marshal(documents, document_with_segments_fields)
else:
data = marshal(documents, document_fields)
response = {
"data": documents,
"data": data,
"has_more": len(documents) == limit,
"limit": limit,
"total": paginated_documents.total,
"page": page,
}
return dump_response(DocumentWithSegmentsListResponse, response)
return response
@setup_required
@login_required
@ -442,7 +482,9 @@ class DatasetDocumentListApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
return dump_response(DatasetAndDocumentResponse, {"dataset": dataset, "documents": documents, "batch": batch})
return DatasetAndDocumentResponse.model_validate(
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
).model_dump(mode="json")
@setup_required
@login_required
@ -525,7 +567,9 @@ class DatasetInitApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
return dump_response(DatasetAndDocumentResponse, {"dataset": dataset, "documents": documents, "batch": batch})
return DatasetAndDocumentResponse.model_validate(
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
).model_dump(mode="json")
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate")
@ -698,9 +742,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
@console_ns.route("/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
class DocumentBatchIndexingStatusApi(DocumentResource):
@console_ns.response(
200, "Indexing status retrieved successfully", console_ns.models[DocumentStatusListResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@ -743,8 +784,9 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
"completed_segments": completed_segments,
"total_segments": total_segments,
}
documents_status.append(document_dict)
return dump_response(DocumentStatusListResponse, {"data": documents_status})
documents_status.append(marshal(document_dict, document_status_fields))
data = {"data": documents_status}
return data
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
@ -752,9 +794,7 @@ class DocumentIndexingStatusApi(DocumentResource):
@console_ns.doc("get_document_indexing_status")
@console_ns.doc(description="Get document indexing status")
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@console_ns.response(
200, "Indexing status retrieved successfully", console_ns.models[DocumentStatusResponse.__name__]
)
@console_ns.response(200, "Indexing status retrieved successfully")
@console_ns.response(404, "Document not found")
@setup_required
@login_required
@ -799,7 +839,7 @@ class DocumentIndexingStatusApi(DocumentResource):
"completed_segments": completed_segments,
"total_segments": total_segments,
}
return dump_response(DocumentStatusResponse, document_dict)
return marshal(document_dict, document_status_fields)
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
@ -1264,7 +1304,7 @@ class DocumentRenameApi(DocumentResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return dump_response(DocumentResponse, document)
return DocumentResponse.model_validate(document, from_attributes=True).model_dump(mode="json")
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")

View File

@ -1,10 +1,9 @@
import uuid
from typing import Literal
from typing import cast as type_cast
from uuid import UUID
from flask import request
from flask_restx import Resource
from flask_restx import Resource, marshal
from pydantic import BaseModel, Field
from sqlalchemy import String, case, cast, func, literal, or_, select
from sqlalchemy.dialects.postgresql import JSONB
@ -14,12 +13,7 @@ import services
from configs import dify_config
from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import (
query_params_from_model,
query_params_from_request,
register_response_schema_models,
register_schema_models,
)
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import (
@ -40,17 +34,9 @@ from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.base import ResponseModel
from fields.segment_fields import (
ChildChunkDetailResponse,
ChildChunkListResponse,
ChildChunkResponse,
SegmentDetailResponse,
SegmentResponse,
segment_response_with_summary,
segment_responses_with_summaries,
)
from fields.segment_fields import child_chunk_fields, segment_fields
from graphon.model_runtime.entities.model_entities import ModelType
from libs.helper import dump_response, escape_like_pattern
from libs.helper import escape_like_pattern
from libs.login import current_account_with_tenant, login_required
from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile
@ -58,10 +44,20 @@ from services.dataset_service import DatasetService, DocumentService, SegmentSer
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
from services.summary_index_service import SummaryIndexService
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
def _get_segment_with_summary(segment, dataset_id):
"""Helper function to marshal segment and add summary information."""
from services.summary_index_service import SummaryIndexService
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
# Query summary for this segment (only enabled summaries)
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict
class SegmentListQuery(BaseModel):
limit: int = Field(default=20, ge=1, le=100)
status: list[str] = Field(default_factory=list)
@ -71,16 +67,6 @@ class SegmentListQuery(BaseModel):
page: int = Field(default=1, ge=1)
class SegmentIdListQuery(BaseModel):
segment_id: list[str] = Field(default_factory=list, description="Segment IDs")
class ChildChunkListQuery(BaseModel):
limit: int = Field(default=20, ge=1, le=100)
keyword: str | None = None
page: int = Field(default=1, ge=1)
class SegmentCreatePayload(BaseModel):
content: str
answer: str | None = None
@ -106,35 +92,13 @@ class SegmentBatchImportStatusResponse(ResponseModel):
job_status: str
class ConsoleSegmentListResponse(ResponseModel):
data: list[SegmentResponse]
limit: int
total: int
total_pages: int
page: int
class ChildChunkBatchUpdateResponse(ResponseModel):
data: list[ChildChunkResponse]
class ChildChunkBatchUpdatePayload(BaseModel):
chunks: list[ChildChunkUpdateArgs]
class SegmentDocParams:
DATASET_DOCUMENT = {"dataset_id": "Dataset ID", "document_id": "Document ID"}
DATASET_DOCUMENT_ACTION = {**DATASET_DOCUMENT, "action": "Action"}
DATASET_DOCUMENT_SEGMENT = {**DATASET_DOCUMENT, "segment_id": "Segment ID"}
DATASET_DOCUMENT_PARENT_SEGMENT = {**DATASET_DOCUMENT, "segment_id": "Parent segment ID"}
DATASET_DOCUMENT_CHILD_CHUNK = {**DATASET_DOCUMENT_PARENT_SEGMENT, "child_chunk_id": "Child chunk ID"}
register_schema_models(
console_ns,
SegmentListQuery,
SegmentIdListQuery,
ChildChunkListQuery,
SegmentCreatePayload,
SegmentUpdatePayload,
BatchImportPayload,
@ -143,24 +107,11 @@ register_schema_models(
ChildChunkBatchUpdatePayload,
ChildChunkUpdateArgs,
)
register_response_schema_models(
console_ns,
SegmentResponse,
ConsoleSegmentListResponse,
SegmentDetailResponse,
ChildChunkDetailResponse,
ChildChunkListResponse,
ChildChunkBatchUpdateResponse,
SegmentBatchImportStatusResponse,
SimpleResultResponse,
)
register_response_schema_models(console_ns, SegmentBatchImportStatusResponse, SimpleResultResponse)
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
class DatasetDocumentSegmentListApi(Resource):
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
@console_ns.doc(params=query_params_from_model(SegmentListQuery))
@console_ns.response(200, "Segments retrieved successfully", console_ns.models[ConsoleSegmentListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@ -183,7 +134,12 @@ class DatasetDocumentSegmentListApi(Resource):
if not document:
raise NotFound("Document not found.")
args = query_params_from_request(SegmentListQuery, list_fields=("status",))
args = SegmentListQuery.model_validate(
{
**request.args.to_dict(),
"status": request.args.getlist("status"),
}
)
page = args.page
limit = min(args.limit, 100)
@ -249,30 +205,38 @@ class DatasetDocumentSegmentListApi(Resource):
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
segment_list = list(segments.items)
segment_ids = [segment.id for segment in segment_list]
summaries: dict[str, str | None] = {}
# Query summaries for all segments in this page (batch query for efficiency)
segment_ids = [segment.id for segment in segments.items]
summaries = {}
if segment_ids:
from services.summary_index_service import SummaryIndexService
summary_records = SummaryIndexService.get_segments_summaries(
segment_ids=segment_ids, dataset_id=dataset_id_str
)
# Only include enabled summaries (already filtered by service)
summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()}
# Add summary to each segment
segments_with_summary = []
for segment in segments.items:
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
segment_dict["summary"] = summaries.get(segment.id)
segments_with_summary.append(segment_dict)
response = {
"data": segment_responses_with_summaries(segment_list, summaries),
"data": segments_with_summary,
"limit": limit,
"total": segments.total,
"total_pages": segments.pages,
"page": page,
}
return dump_response(ConsoleSegmentListResponse, response), 200
return response, 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
@console_ns.doc(params=query_params_from_model(SegmentIdListQuery))
@console_ns.response(204, "Segments deleted successfully")
def delete(self, dataset_id: UUID, document_id: UUID):
current_user, _ = current_account_with_tenant()
@ -304,8 +268,6 @@ class DatasetDocumentSegmentListApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>")
class DatasetDocumentSegmentApi(Resource):
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_ACTION)
@console_ns.doc(params=query_params_from_model(SegmentIdListQuery))
@setup_required
@login_required
@account_initialization_required
@ -359,12 +321,11 @@ class DatasetDocumentSegmentApi(Resource):
SegmentService.update_segments_status(segment_ids, action, dataset, document)
except Exception as e:
raise InvalidActionError(str(e))
return dump_response(SimpleResultResponse, {"result": "success"}), 200
return {"result": "success"}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
class DatasetDocumentSegmentAddApi(Resource):
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
@setup_required
@login_required
@account_initialization_required
@ -372,7 +333,6 @@ class DatasetDocumentSegmentAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
@console_ns.response(200, "Segment created successfully", console_ns.models[SegmentDetailResponse.__name__])
def post(self, dataset_id: UUID, document_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@ -412,25 +372,18 @@ class DatasetDocumentSegmentAddApi(Resource):
payload = SegmentCreatePayload.model_validate(console_ns.payload or {})
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
segment = type_cast(DocumentSegment, SegmentService.create_segment(payload_dict, document, dataset))
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id_str)
response = {
"data": segment_response_with_summary(segment, summary.summary_content if summary else None),
"doc_form": document.doc_form,
}
return dump_response(SegmentDetailResponse, response), 200
segment = SegmentService.create_segment(payload_dict, document, dataset)
return {"data": _get_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
class DatasetDocumentSegmentUpdateApi(Resource):
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
@console_ns.response(200, "Segment updated successfully", console_ns.models[SegmentDetailResponse.__name__])
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@ -487,18 +440,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment = SegmentService.update_segment(
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
)
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id_str)
response = {
"data": segment_response_with_summary(segment, summary.summary_content if summary else None),
"doc_form": document.doc_form,
}
return dump_response(SegmentDetailResponse, response), 200
return {"data": _get_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
@console_ns.response(204, "Segment deleted successfully")
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@ -576,11 +523,11 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
try:
# async job
job_id = str(uuid.uuid4())
indexing_cache_key = f"segment_batch_import_{job_id}"
indexing_cache_key = f"segment_batch_import_{str(job_id)}"
# send batch add segments task
redis_client.setnx(indexing_cache_key, "waiting")
batch_create_segment_to_index_task.delay(
job_id,
str(job_id),
upload_file_id,
dataset_id_str,
document_id_str,
@ -589,7 +536,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
)
except Exception as e:
return {"error": str(e)}, 500
return dump_response(SegmentBatchImportStatusResponse, {"job_id": job_id, "job_status": "waiting"}), 200
return {"job_id": job_id, "job_status": "waiting"}, 200
@console_ns.response(200, "Batch import status", console_ns.models[SegmentBatchImportStatusResponse.__name__])
@setup_required
@ -604,13 +551,11 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
if cache_result is None:
raise ValueError("The job does not exist.")
response = {"job_id": job_id, "job_status": cache_result.decode()}
return dump_response(SegmentBatchImportStatusResponse, response), 200
return {"job_id": job_id, "job_status": cache_result.decode()}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks")
class ChildChunkAddApi(Resource):
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
@setup_required
@login_required
@account_initialization_required
@ -618,7 +563,6 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
@console_ns.response(200, "Child chunk created successfully", console_ns.models[ChildChunkDetailResponse.__name__])
def post(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@ -669,11 +613,8 @@ class ChildChunkAddApi(Resource):
child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return dump_response(ChildChunkDetailResponse, {"data": child_chunk}), 200
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
@console_ns.doc(params=query_params_from_model(ChildChunkListQuery))
@console_ns.response(200, "Child chunks retrieved successfully", console_ns.models[ChildChunkListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@ -701,7 +642,13 @@ class ChildChunkAddApi(Resource):
)
if not segment:
raise NotFound("Segment not found.")
args = query_params_from_request(ChildChunkListQuery, use_defaults_for_malformed_ints=True)
args = SegmentListQuery.model_validate(
{
"limit": request.args.get("limit", default=20, type=int),
"keyword": request.args.get("keyword"),
"page": request.args.get("page", default=1, type=int),
}
)
page = args.page
limit = min(args.limit, 100)
@ -710,27 +657,19 @@ class ChildChunkAddApi(Resource):
child_chunks = SegmentService.get_child_chunks(
segment_id_str, document_id_str, dataset_id_str, page, limit, keyword
)
response = {
"data": child_chunks.items,
return {
"data": marshal(child_chunks.items, child_chunk_fields),
"total": child_chunks.total,
"total_pages": child_chunks.pages,
"page": page,
"limit": limit,
}
return dump_response(ChildChunkListResponse, response), 200
}, 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
@console_ns.response(
200,
"Child chunks updated successfully",
console_ns.models[ChildChunkBatchUpdateResponse.__name__],
)
@console_ns.expect(console_ns.models[ChildChunkBatchUpdatePayload.__name__])
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@ -768,7 +707,7 @@ class ChildChunkAddApi(Resource):
child_chunks = SegmentService.update_child_chunks(payload.chunks, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return dump_response(ChildChunkBatchUpdateResponse, {"data": child_chunks}), 200
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
@console_ns.route(
@ -779,7 +718,6 @@ class ChildChunkUpdateApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
@console_ns.response(204, "Child chunk deleted successfully")
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@ -810,7 +748,7 @@ class ChildChunkUpdateApi(Resource):
child_chunk = db.session.scalar(
select(ChildChunk)
.where(
ChildChunk.id == child_chunk_id_str,
ChildChunk.id == str(child_chunk_id_str),
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id_str,
@ -837,9 +775,7 @@ class ChildChunkUpdateApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
@console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
@console_ns.response(200, "Child chunk updated successfully", console_ns.models[ChildChunkDetailResponse.__name__])
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@ -869,7 +805,7 @@ class ChildChunkUpdateApi(Resource):
child_chunk = db.session.scalar(
select(ChildChunk)
.where(
ChildChunk.id == child_chunk_id_str,
ChildChunk.id == str(child_chunk_id_str),
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id_str,
@ -891,4 +827,4 @@ class ChildChunkUpdateApi(Resource):
child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return dump_response(ChildChunkDetailResponse, {"data": child_chunk}), 200
return {"data": marshal(child_chunk, child_chunk_fields)}, 200

View File

@ -7,20 +7,14 @@ from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import MetadataUpdatePayload
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
setup_required,
with_current_user,
)
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import (
DatasetMetadataBuiltInFieldsResponse,
DatasetMetadataListResponse,
DatasetMetadataResponse,
)
from libs.helper import dump_response
from libs.login import login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
DocumentMetadataOperation,
@ -49,8 +43,8 @@ class DatasetMetadataCreateApi(Resource):
@enterprise_license_required
@console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__])
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
@with_current_user
def post(self, current_user: Account, dataset_id: UUID):
def post(self, dataset_id: UUID):
current_user, _ = current_account_with_tenant()
metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
dataset_id_str = str(dataset_id)
@ -86,8 +80,8 @@ class DatasetMetadataApi(Resource):
@enterprise_license_required
@console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__])
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
@with_current_user
def patch(self, current_user: Account, dataset_id: UUID, metadata_id: UUID):
def patch(self, dataset_id: UUID, metadata_id: UUID):
current_user, _ = current_account_with_tenant()
payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
name = payload.name
@ -106,8 +100,8 @@ class DatasetMetadataApi(Resource):
@account_initialization_required
@enterprise_license_required
@console_ns.response(204, "Metadata deleted successfully")
@with_current_user
def delete(self, current_user: Account, dataset_id: UUID, metadata_id: UUID):
def delete(self, dataset_id: UUID, metadata_id: UUID):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -143,8 +137,8 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
@account_initialization_required
@enterprise_license_required
@console_ns.response(204, "Action completed successfully")
@with_current_user
def post(self, current_user: Account, dataset_id: UUID, action: Literal["enable", "disable"]):
def post(self, dataset_id: UUID, action: Literal["enable", "disable"]):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -171,8 +165,8 @@ class DocumentMetadataEditApi(Resource):
204,
"Documents metadata updated successfully",
)
@with_current_user
def post(self, current_user: Account, dataset_id: UUID):
def post(self, dataset_id: UUID):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:

View File

@ -198,13 +198,12 @@ class DatasourceAuth(Resource):
def get(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
user, current_tenant_id = current_account_with_tenant()
_, current_tenant_id = current_account_with_tenant()
datasources = datasource_provider_service.list_datasource_credentials(
tenant_id=current_tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
user=user,
)
return {"result": datasources}, 200

View File

@ -11,13 +11,10 @@ from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
setup_required,
with_current_tenant_id,
with_current_user,
)
from extensions.ext_database import db
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
@ -38,10 +35,9 @@ class CreateRagPipelineDatasetApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account):
def post(self):
payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {})
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
@ -83,10 +79,10 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account):
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_dataset_editor:
raise Forbidden()
dataset = DatasetService.create_empty_rag_pipeline_dataset(

View File

@ -10,7 +10,6 @@ from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user,
)
from extensions.ext_database import db
from fields.rag_pipeline_fields import (
@ -18,8 +17,7 @@ from fields.rag_pipeline_fields import (
pipeline_import_check_dependencies_fields,
pipeline_import_fields,
)
from libs.login import login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models.dataset import Pipeline
from services.entities.dsl_entities import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
@ -64,9 +62,9 @@ class RagPipelineImportApi(Resource):
@edit_permission_required
@marshal_with(pipeline_import_model)
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
@with_current_user
def post(self, current_user: Account):
def post(self):
# Check user role first
current_user, _ = current_account_with_tenant()
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
# Use a plain Session so that caught exceptions inside the service
@ -107,8 +105,9 @@ class RagPipelineImportConfirmApi(Resource):
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_model)
@with_current_user
def post(self, current_user: Account, import_id: str):
def post(self, import_id: str):
current_user, _ = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
import_service = RagPipelineDslService(session)
account = current_user

View File

@ -18,7 +18,6 @@ from controllers.console.app.error import (
)
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import with_current_user_id
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
@ -136,18 +135,20 @@ class CompletionApi(InstalledAppResource):
)
class CompletionStopApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
@with_current_user_id
def post(self, current_user_id: str, installed_app: InstalledApp, task_id: str):
def post(self, installed_app: InstalledApp, task_id: str):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.EXPLORE,
user_id=current_user_id,
user_id=current_user.id,
app_mode=AppMode.value_of(app_model.mode),
)
@ -214,8 +215,7 @@ class ChatApi(InstalledAppResource):
)
class ChatStopApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
@with_current_user_id
def post(self, current_user_id: str, installed_app: InstalledApp, task_id: str):
def post(self, installed_app: InstalledApp, task_id: str):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
@ -223,10 +223,13 @@ class ChatStopApi(InstalledAppResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.EXPLORE,
user_id=current_user_id,
user_id=current_user.id,
app_mode=app_mode,
)

View File

@ -12,19 +12,14 @@ from controllers.common.fields import SimpleMessageResponse, SimpleResultMessage
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
with_current_tenant_id,
with_current_user,
)
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db
from fields.base import ResponseModel
from graphon.file import helpers as file_helpers
from libs.datetime_utils import naive_utc_now
from libs.helper import to_timestamp
from libs.login import login_required
from models import Account, App, InstalledApp, RecommendedApp
from libs.login import current_account_with_tenant, login_required
from models import App, InstalledApp, RecommendedApp
from models.model import IconType
from services.account_service import TenantService
from services.enterprise.enterprise_service import EnterpriseService
@ -136,10 +131,9 @@ class InstalledAppsListApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__])
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
def get(self):
query = InstalledAppsListQuery.model_validate(request.args.to_dict())
current_user, current_tenant_id = current_account_with_tenant()
if query.app_id:
installed_apps = db.session.scalars(
@ -218,8 +212,7 @@ class InstalledAppsListApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("apps")
@console_ns.response(200, "Success", console_ns.models[SimpleMessageResponse.__name__])
@with_current_tenant_id
def post(self, current_tenant_id: str):
def post(self):
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
recommended_app = db.session.scalar(
@ -228,6 +221,8 @@ class InstalledAppsListApi(Resource):
if recommended_app is None:
raise NotFound("Recommended app not found")
_, current_tenant_id = current_account_with_tenant()
app = db.session.get(App, payload.app_id)
if app is None:
@ -267,8 +262,8 @@ class InstalledAppApi(InstalledAppResource):
"""
@console_ns.response(204, "App uninstalled successfully")
@with_current_tenant_id
def delete(self, current_tenant_id: str, installed_app: InstalledApp):
def delete(self, installed_app: InstalledApp):
_, current_tenant_id = current_account_with_tenant()
if installed_app.app_owner_tenant_id == current_tenant_id:
raise BadRequest("You can't uninstall an app owned by the current tenant")

View File

@ -2,36 +2,13 @@ from flask_restx import Resource
from werkzeug.exceptions import Unauthorized
from controllers.common.schema import register_response_schema_models
from fields.base import ResponseModel
from libs.helper import dump_response
from libs.login import current_user, login_required
from services.feature_service import (
FeatureModel,
FeatureService,
LimitationModel,
SystemFeatureModel,
)
from services.feature_service import FeatureModel, FeatureService, LimitationModel, SystemFeatureModel
from . import console_ns
from .wraps import account_initialization_required, cloud_utm_record, setup_required, with_current_tenant_id
class TrialModelsResponse(ResponseModel):
trial_models: list[str]
class AppDslVersionResponse(ResponseModel):
app_dsl_version: str
register_response_schema_models(
console_ns,
AppDslVersionResponse,
FeatureModel,
LimitationModel,
SystemFeatureModel,
TrialModelsResponse,
)
register_response_schema_models(console_ns, FeatureModel, LimitationModel, SystemFeatureModel)
@console_ns.route("/features")
@ -77,43 +54,6 @@ class FeatureVectorSpaceApi(Resource):
return FeatureService.get_vector_space(current_tenant_id).model_dump()
@console_ns.route("/trial-models")
class TrialModelsApi(Resource):
@console_ns.doc("get_trial_models")
@console_ns.doc(description="Get hosted trial model provider configuration")
@console_ns.response(
200,
"Success",
console_ns.models[TrialModelsResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
def get(self):
"""Get hosted trial model provider configuration for model-provider pages."""
return dump_response(
TrialModelsResponse,
{"trial_models": FeatureService.get_trial_models()},
)
@console_ns.route("/app-dsl-version")
class AppDslVersionApi(Resource):
@console_ns.doc("get_app_dsl_version")
@console_ns.doc(description="Get current app DSL version")
@console_ns.response(
200,
"Success",
console_ns.models[AppDslVersionResponse.__name__],
)
def get(self):
"""Get current app DSL version for workflow clipboard compatibility."""
return dump_response(
AppDslVersionResponse,
{"app_dsl_version": FeatureService.get_app_dsl_version()},
)
@console_ns.route("/system-features")
class SystemFeatureApi(Resource):
@console_ns.doc("get_system_features")

View File

@ -12,15 +12,8 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
model_validate,
setup_required,
with_current_tenant_id,
with_current_user,
)
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
@ -29,8 +22,8 @@ from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from libs.login import login_required
from models import Account, App
from libs.login import current_account_with_tenant, login_required
from models import App
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import WorkflowRun
@ -40,8 +33,6 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
logger = logging.getLogger(__name__)
register_schema_models(console_ns, HumanInputFormSubmitPayload)
def _jsonify_form_definition(form: Form) -> Response:
payload = form.get_definition().model_dump()
@ -54,8 +45,9 @@ class ConsoleHumanInputFormApi(Resource):
"""Console API for getting human input form definition."""
@staticmethod
def _ensure_console_access(form: Form, current_tenant_id: str) -> None:
"""Ensure a console form token resolves only inside the current tenant."""
def _ensure_console_access(form: Form):
_, current_tenant_id = current_account_with_tenant()
if form.tenant_id != current_tenant_id:
raise NotFoundError("App not found")
@ -67,8 +59,7 @@ class ConsoleHumanInputFormApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, current_tenant_id: str, form_token: str):
def get(self, form_token: str):
"""
Get human input form definition by form token.
@ -79,23 +70,13 @@ class ConsoleHumanInputFormApi(Resource):
if form is None:
raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form, current_tenant_id)
self._ensure_console_access(form)
return _jsonify_form_definition(form)
@account_initialization_required
@login_required
@with_current_user
@with_current_tenant_id
@model_validate(HumanInputFormSubmitPayload)
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
def post(
self,
payload: HumanInputFormSubmitPayload,
current_tenant_id: str,
current_user: Account,
form_token: str,
):
def post(self, form_token: str):
"""
Submit human input form by form token.
@ -109,12 +90,15 @@ class ConsoleHumanInputFormApi(Resource):
"action": "Approve"
}
"""
payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
current_user, _ = current_account_with_tenant()
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form, current_tenant_id)
self._ensure_console_access(form)
self._ensure_console_recipient_type(form)
recipient_type = form.recipient_type
# The type checker is not smart enought to validate the following invariant.
@ -138,9 +122,7 @@ class ConsoleWorkflowEventsApi(Resource):
@account_initialization_required
@login_required
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, user: Account, workflow_run_id: str):
def get(self, workflow_run_id: str):
"""
Get workflow execution events stream after resume.
@ -148,6 +130,8 @@ class ConsoleWorkflowEventsApi(Resource):
Returns Server-Sent Events stream.
"""
user, tenant_id = current_account_with_tenant()
session_maker = sessionmaker(db.engine)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(

View File

@ -13,7 +13,7 @@ from controllers.common.errors import (
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import with_current_user
from core.file import remote_fetcher
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from graphon.file import helpers as file_helpers
@ -36,9 +36,9 @@ class GetRemoteFileInfo(Resource):
@login_required
def get(self, url: str):
decoded_url = helpers.decode_remote_url(url, request.query_string)
resp = remote_fetcher.make_request("HEAD", decoded_url)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
resp = remote_fetcher.make_request("GET", decoded_url, timeout=3)
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
return RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
@ -58,9 +58,9 @@ class RemoteFileUpload(Resource):
# Try to fetch remote file metadata/content first
try:
resp = remote_fetcher.make_request("HEAD", url=url)
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = remote_fetcher.make_request("GET", url=url, timeout=3, follow_redirects=True)
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
# Normalize into a user-friendly error message expected by tests
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
@ -74,7 +74,7 @@ class RemoteFileUpload(Resource):
raise FileTooLargeError()
# Load content if needed
content = resp.content if resp.request.method == "GET" else remote_fetcher.make_request("GET", url).content
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
upload_file = FileService(db.engine).upload_file(

View File

@ -18,7 +18,7 @@ from controllers.common.fields import (
SimpleResultResponse,
VerificationTokenResponse,
)
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.auth.error import (
EmailAlreadyInUseError,
@ -48,7 +48,7 @@ from fields.base import ResponseModel
from fields.member_fields import Account as AccountResponse
from graphon.file import helpers as file_helpers
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, dump_response, extract_remote_ip, timezone, to_timestamp
from libs.helper import EmailStr, extract_remote_ip, timezone, to_timestamp
from libs.login import current_account_with_tenant, login_required
from models import AccountIntegrate, InvitationCode
from models.account import AccountStatus, InvitationCodeStatus
@ -173,6 +173,7 @@ class CheckEmailUniquePayload(BaseModel):
register_schema_models(
console_ns,
AccountResponse,
AccountInitPayload,
AccountNamePayload,
AccountAvatarPayload,
@ -244,7 +245,6 @@ register_schema_models(
)
register_response_schema_models(
console_ns,
AccountResponse,
AvatarUrlResponse,
SimpleResultDataResponse,
SimpleResultResponse,
@ -329,9 +329,9 @@ class AccountNameApi(Resource):
@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource):
@console_ns.expect(console_ns.models[AccountAvatarQuery.__name__])
@console_ns.doc("get_account_avatar")
@console_ns.doc(description="Get account avatar url")
@console_ns.doc(params=query_params_from_model(AccountAvatarQuery))
@console_ns.response(200, "Success", console_ns.models[AvatarUrlResponse.__name__])
@setup_required
@login_required
@ -342,7 +342,7 @@ class AccountAvatarApi(Resource):
avatar = args.avatar
if avatar.startswith(("http://", "https://")):
return dump_response(AvatarUrlResponse, {"avatar_url": avatar})
return {"avatar_url": avatar}
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == avatar).limit(1))
if upload_file is None:
@ -355,7 +355,7 @@ class AccountAvatarApi(Resource):
raise NotFound("Avatar file not found")
avatar_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
return dump_response(AvatarUrlResponse, {"avatar_url": avatar_url})
return {"avatar_url": avatar_url}
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
@setup_required

View File

@ -1,15 +1,9 @@
from flask_restx import Resource, fields
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from controllers.console.wraps import account_initialization_required, setup_required
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from services.agent_service import AgentService
@ -25,10 +19,14 @@ class AgentProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
return jsonable_encoder(AgentService.list_agent_providers(current_user.id, current_tenant_id))
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
user = current_user
user_id = user.id
tenant_id = current_tenant_id
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
@console_ns.route("/workspaces/current/agent-provider/<path:provider_name>")
@ -44,7 +42,6 @@ class AgentProviderApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, provider_name: str):
def get(self, provider_name: str):
current_user, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(AgentService.get_agent_provider(current_user.id, current_tenant_id, provider_name))

View File

@ -4,16 +4,11 @@ from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from controllers.console.wraps import account_initialization_required, setup_required
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from libs.login import login_required
from models import Account, TenantAccountRole
from libs.login import current_account_with_tenant, login_required
from models import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
@ -34,9 +29,8 @@ class LoadBalancingCredentialsValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account, provider: str):
def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden()
@ -78,9 +72,8 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account, provider: str, config_id: str):
def post(self, provider: str, config_id: str):
current_user, current_tenant_id = current_account_with_tenant()
if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden()

View File

@ -25,13 +25,12 @@ from controllers.console.wraps import (
account_initialization_required,
is_allow_transfer_owner,
setup_required,
with_current_user,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.member_fields import AccountWithRole, AccountWithRoleList
from libs.helper import extract_remote_ip
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountJoin, TenantAccountRole
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError
@ -137,8 +136,8 @@ class MemberListApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
@with_current_user
def get(self, current_user: Account):
def get(self):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_tenant_members(current_user.current_tenant)
@ -155,8 +154,7 @@ class MemberInviteEmailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def post(self, current_user: Account):
def post(self):
payload = console_ns.payload or {}
args = MemberInvitePayload.model_validate(payload)
@ -165,6 +163,7 @@ class MemberInviteEmailApi(Resource):
interface_language = args.language
if not TenantAccountRole.is_non_owner_role(invitee_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
current_user, _ = current_account_with_tenant()
inviter = current_user
if not inviter.current_tenant:
raise ValueError("No current tenant")
@ -224,8 +223,8 @@ class MemberCancelInviteApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def delete(self, current_user: Account, member_id: UUID):
def delete(self, member_id: UUID):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.get(Account, str(member_id))
@ -257,14 +256,14 @@ class MemberUpdateRoleApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def put(self, current_user: Account, member_id: UUID):
def put(self, member_id: UUID):
payload = console_ns.payload or {}
args = MemberRoleUpdatePayload.model_validate(payload)
new_role = args.role
if not TenantAccountRole.is_valid_role(new_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not _is_role_enabled(new_role, current_user.current_tenant.id):
@ -298,8 +297,8 @@ class DatasetOperatorMemberListApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
@with_current_user
def get(self, current_user: Account):
def get(self):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
@ -318,13 +317,13 @@ class SendOwnerTransferEmailApi(Resource):
@login_required
@account_initialization_required
@is_allow_transfer_owner
@with_current_user
def post(self, current_user: Account):
def post(self):
payload = console_ns.payload or {}
args = OwnerTransferEmailPayload.model_validate(payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
current_user, _ = current_account_with_tenant()
# check if the current user is the owner of the workspace
if not current_user.current_tenant:
raise ValueError("No current tenant")
@ -356,11 +355,11 @@ class OwnerTransferCheckApi(Resource):
@login_required
@account_initialization_required
@is_allow_transfer_owner
@with_current_user
def post(self, current_user: Account):
def post(self):
payload = console_ns.payload or {}
args = OwnerTransferCheckPayload.model_validate(payload)
# check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
@ -400,12 +399,12 @@ class OwnerTransfer(Resource):
@login_required
@account_initialization_required
@is_allow_transfer_owner
@with_current_user
def post(self, current_user: Account, member_id: UUID):
def post(self, member_id: UUID):
payload = console_ns.payload or {}
args = OwnerTransferPayload.model_validate(payload)
# check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):

View File

@ -8,19 +8,12 @@ from pydantic import BaseModel, Field, field_validator
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import uuid_value
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
@ -102,8 +95,10 @@ class ModelProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str):
def get(self):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
payload = request.args.to_dict(flat=True)
args = ParserModelList.model_validate(payload)
@ -119,8 +114,9 @@ class ModelProviderCredentialApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str, provider: str):
def get(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
# if credential_id is not provided, return current used credential
payload = request.args.to_dict(flat=True)
args = ParserCredentialId.model_validate(payload)
@ -137,8 +133,8 @@ class ModelProviderCredentialApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_tenant_id
def post(self, current_tenant_id: str, provider: str):
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
payload = console_ns.payload or {}
args = ParserCredentialCreate.model_validate(payload)
@ -161,8 +157,9 @@ class ModelProviderCredentialApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_tenant_id
def put(self, current_tenant_id: str, provider: str):
def put(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
payload = console_ns.payload or {}
args = ParserCredentialUpdate.model_validate(payload)
@ -187,8 +184,8 @@ class ModelProviderCredentialApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_tenant_id
def delete(self, current_tenant_id: str, provider: str):
def delete(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
payload = console_ns.payload or {}
args = ParserCredentialDelete.model_validate(payload)
@ -208,8 +205,8 @@ class ModelProviderCredentialSwitchApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_tenant_id
def post(self, current_tenant_id: str, provider: str):
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
payload = console_ns.payload or {}
args = ParserCredentialSwitch.model_validate(payload)
@ -228,8 +225,8 @@ class ModelProviderValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def post(self, current_tenant_id: str, provider: str):
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
payload = console_ns.payload or {}
args = ParserCredentialValidate.model_validate(payload)
@ -283,8 +280,11 @@ class PreferredProviderTypeUpdateApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_tenant_id
def post(self, tenant_id: str, provider: str):
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
payload = console_ns.payload or {}
args = ParserPreferredProviderType.model_validate(payload)
@ -301,11 +301,10 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, provider: str):
def get(self, provider: str):
if provider != "anthropic":
raise ValueError(f"provider name {provider} is invalid")
current_user, current_tenant_id = current_account_with_tenant()
BillingService.is_tenant_owner_or_admin(current_user)
data = BillingService.get_model_provider_payment_link(
provider_name=provider,

View File

@ -13,14 +13,12 @@ from controllers.console.wraps import (
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import uuid_value
from libs.login import login_required
from models import Account
from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService
@ -195,7 +193,7 @@ class ModelProviderModelApi(Resource):
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str, provider: str):
def get(self, tenant_id: str, provider):
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
@ -271,9 +269,8 @@ class ModelProviderModelCredentialApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, user: Account, provider: str):
def get(self, tenant_id: str, provider: str):
args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True))
model_provider_service = ModelProviderService()
@ -295,13 +292,9 @@ class ModelProviderModelCredentialApi(Resource):
)
if args.config_from == "predefined-model":
# Only the predefined-model branch needs visibility filtering by user.
# The account is injected once by the handler and only passed into the
# service branch that needs user-scoped credential visibility.
available_credentials = model_provider_service.get_provider_available_credentials(
tenant_id=tenant_id,
provider=provider,
user=user,
)
else:
available_credentials = model_provider_service.get_provider_model_available_credentials(

View File

@ -69,7 +69,6 @@ class BuiltinToolAddPayload(BaseModel):
credentials: dict[str, Any]
name: str | None = Field(default=None, max_length=30)
type: CredentialType
visibility: str | None = None
class BuiltinToolUpdatePayload(BaseModel):
@ -278,7 +277,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
def get(self, provider):
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
@ -294,7 +293,7 @@ class ToolBuiltinProviderInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
def get(self, provider):
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
@ -307,7 +306,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
def post(self, provider):
_, tenant_id = current_account_with_tenant()
payload = BuiltinToolCredentialDeletePayload.model_validate(console_ns.payload or {})
@ -325,7 +324,7 @@ class ToolBuiltinProviderAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
def post(self, provider):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@ -339,7 +338,6 @@ class ToolBuiltinProviderAddApi(Resource):
credentials=payload.credentials,
name=payload.name,
api_type=CredentialType.of(payload.type),
visibility=payload.visibility,
)
@ -350,7 +348,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
def post(self, provider):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@ -372,20 +370,13 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
user, tenant_id = current_account_with_tenant()
# Optional list of credential IDs to include even if visibility would hide them
# (used when a workflow/agent node still references another member's only_me credential).
include_credential_ids = request.args.getlist("include_credential_ids") or [
s for s in (request.args.get("include_credential_ids") or "").split(",") if s
]
def get(self, provider):
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credentials(
tenant_id=tenant_id,
provider_name=provider,
user=user,
include_credential_ids=include_credential_ids or None,
)
)
@ -393,7 +384,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/icon")
class ToolBuiltinProviderIconApi(Resource):
@setup_required
def get(self, provider: str):
def get(self, provider):
icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider)
icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
@ -793,7 +784,7 @@ class ToolPluginOAuthApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def get(self, provider: str):
def get(self, provider):
tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name
@ -831,7 +822,7 @@ class ToolPluginOAuthApi(Resource):
@console_ns.route("/oauth/plugin/<path:provider>/tool/callback")
class ToolOAuthCallback(Resource):
@setup_required
def get(self, provider: str):
def get(self, provider):
context_id = request.cookies.get("context_id")
if not context_id:
raise Forbidden("context_id not found")
@ -868,7 +859,7 @@ class ToolOAuthCallback(Resource):
if not credentials:
raise Exception("the plugin credentials failed")
# add credentials to database — OAuth tokens default to only_me since they're personal
# add credentials to database
BuiltinToolManageService.add_builtin_tool_provider(
user_id=user_id,
tenant_id=tenant_id,
@ -876,7 +867,6 @@ class ToolOAuthCallback(Resource):
credentials=dict(credentials),
expires_at=expires_at,
api_type=CredentialType.OAUTH2,
visibility="only_me",
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
@ -888,7 +878,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
def post(self, provider):
_, current_tenant_id = current_account_with_tenant()
payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {})
return BuiltinToolManageService.set_default_provider(
@ -920,7 +910,7 @@ class ToolOAuthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
def get(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
@ -929,7 +919,7 @@ class ToolOAuthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, provider: str):
def delete(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
@ -941,7 +931,7 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
def get(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
@ -955,18 +945,13 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
user, tenant_id = current_account_with_tenant()
include_credential_ids = request.args.getlist("include_credential_ids") or [
s for s in (request.args.get("include_credential_ids") or "").split(",") if s
]
def get(self, provider):
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
tenant_id=tenant_id,
provider=provider,
user=user,
include_credential_ids=include_credential_ids or None,
)
)
@ -1166,7 +1151,7 @@ class ToolMCPDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_id: str):
def get(self, provider_id):
_, tenant_id = current_account_with_tenant()
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
@ -1195,7 +1180,7 @@ class ToolMCPUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_id: str):
def get(self, provider_id):
_, tenant_id = current_account_with_tenant()
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)

View File

@ -77,7 +77,7 @@ class TriggerProviderIconApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
def get(self, provider):
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
@ -103,7 +103,7 @@ class TriggerProviderInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
def get(self, provider):
"""Get info for a trigger provider"""
user = current_user
assert isinstance(user, Account)
@ -119,18 +119,15 @@ class TriggerSubscriptionListApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def get(self, provider: str):
def get(self, provider):
"""List all trigger subscriptions for the current tenant's provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
try:
return jsonable_encoder(
TriggerProviderService.list_trigger_provider_subscriptions(
tenant_id=user.current_tenant_id,
provider_id=TriggerProviderID(provider),
user=user,
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
)
)
except ValueError as e:
@ -149,7 +146,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider: str):
def post(self, provider):
"""Add a new subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
@ -178,7 +175,7 @@ class TriggerSubscriptionBuilderGetApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def get(self, provider: str, subscription_builder_id: str):
def get(self, provider, subscription_builder_id):
"""Get a subscription instance for a trigger provider"""
return jsonable_encoder(
TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
@ -194,7 +191,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider: str, subscription_builder_id: str):
def post(self, provider, subscription_builder_id):
"""Verify and update a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
@ -226,7 +223,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider: str, subscription_builder_id: str):
def post(self, provider, subscription_builder_id):
"""Update a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
@ -260,7 +257,7 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def get(self, provider: str, subscription_builder_id: str):
def get(self, provider, subscription_builder_id):
"""Get the request logs for a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
@ -283,7 +280,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider: str, subscription_builder_id: str):
def post(self, provider, subscription_builder_id):
"""Build a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
@ -407,7 +404,7 @@ class TriggerOAuthAuthorizeApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
def get(self, provider):
"""Initiate OAuth authorization flow for a trigger provider"""
user = current_user
assert isinstance(user, Account)
@ -489,7 +486,7 @@ class TriggerOAuthAuthorizeApi(Resource):
@console_ns.route("/oauth/plugin/<path:provider>/trigger/callback")
class TriggerOAuthCallbackApi(Resource):
@setup_required
def get(self, provider: str):
def get(self, provider):
"""Handle OAuth callback for trigger provider"""
context_id = request.cookies.get("context_id")
if not context_id:
@ -557,7 +554,7 @@ class TriggerOAuthClientManageApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def get(self, provider: str):
def get(self, provider):
"""Get OAuth client configuration for a provider"""
user = current_user
assert user.current_tenant_id is not None
@ -603,7 +600,7 @@ class TriggerOAuthClientManageApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
def post(self, provider):
"""Configure custom OAuth client for a provider"""
user = current_user
assert user.current_tenant_id is not None
@ -629,7 +626,7 @@ class TriggerOAuthClientManageApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider: str):
def delete(self, provider):
"""Remove custom OAuth client configuration"""
user = current_user
assert user.current_tenant_id is not None
@ -657,7 +654,7 @@ class TriggerSubscriptionVerifyApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider: str, subscription_id: str):
def post(self, provider, subscription_id):
"""Verify credentials for an existing subscription (edit mode only)"""
user = current_user
assert user.current_tenant_id is not None

View File

@ -29,7 +29,7 @@ from controllers.console.wraps import (
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import TimestampField, dump_response, to_timestamp
from libs.helper import TimestampField, to_timestamp
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
from services.account_service import TenantService
@ -56,11 +56,6 @@ class WorkspaceCustomConfigPayload(BaseModel):
replace_webapp_logo: str | None = None
class WorkspaceCustomConfigResponse(ResponseModel):
remove_webapp_brand: bool | None = None
replace_webapp_logo: str | None = None
class WorkspaceInfoPayload(BaseModel):
name: str
@ -74,7 +69,7 @@ class TenantInfoResponse(ResponseModel):
role: str | None = None
in_trial: bool | None = None
trial_end_reason: str | None = None
custom_config: WorkspaceCustomConfigResponse | None = None
custom_config: dict | None = None
trial_credits: int | None = None
trial_credits_used: int | None = None
next_credit_reset_date: int | None = None
@ -106,13 +101,9 @@ register_schema_models(
SwitchWorkspacePayload,
WorkspaceCustomConfigPayload,
WorkspaceInfoPayload,
)
register_response_schema_models(
console_ns,
TenantInfoResponse,
WorkspaceCustomConfigResponse,
WorkspacePermissionResponse,
)
register_response_schema_models(console_ns, WorkspacePermissionResponse)
provider_fields = {
"provider_name": fields.String,
@ -247,7 +238,13 @@ class TenantApi(Resource):
else:
raise Unauthorized("workspace is archived")
return dump_response(TenantInfoResponse, WorkspaceService.get_tenant_info(tenant)), 200
return (
TenantInfoResponse.model_validate(
WorkspaceService.get_tenant_info(tenant),
from_attributes=True,
).model_dump(mode="json"),
200,
)
@console_ns.route("/workspaces/switch")

View File

@ -7,9 +7,7 @@ from functools import wraps
from typing import Concatenate
from flask import abort, request
from pydantic import BaseModel, ValidationError
from sqlalchemy import select
from werkzeug.exceptions import UnprocessableEntity
from configs import dify_config
from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError
@ -514,79 +512,9 @@ def with_current_tenant_id[T, **P, R](
def with_current_user[T, **P, R](
view: Callable[Concatenate[T, Account, P], R],
) -> Callable[Concatenate[T, P], R]:
"""Inject the current authenticated Account into the handler as the first argument after self.
Usage::
class MyResource(Resource):
@login_required
@with_current_user
def get(self, current_user: Account):
...
"""
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
current_user, _ = current_account_with_tenant()
return view(self, current_user, *args, **kwargs)
return decorated
def with_current_user_id[T, **P, R](
view: Callable[Concatenate[T, str, P], R],
) -> Callable[Concatenate[T, P], R]:
"""Inject the current authenticated user's ID (as a string) into the handler.
Use this when the handler only needs the user ID and not the full Account object.
Usage::
class MyResource(Resource):
@login_required
@with_current_user_id
def get(self, current_user_id: str):
...
"""
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
current_user, _ = current_account_with_tenant()
return view(self, str(current_user.id), *args, **kwargs)
return decorated
def model_validate[T, M: BaseModel, **P, R](
model: type[M],
) -> Callable[
[Callable[Concatenate[T, M, P], R]],
Callable[Concatenate[T, P], R],
]:
"""Validate request data and inject the model instance as the first arg after self.
Source is determined by HTTP method:
GET/DELETE -> request.args
POST/PUT/PATCH -> JSON body
"""
def decorator(
view: Callable[Concatenate[T, M, P], R],
) -> Callable[Concatenate[T, P], R]:
@wraps(view)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
if request.method in ("GET", "DELETE"):
raw = request.args.to_dict(flat=True)
else:
raw = request.get_json(silent=True) or {}
try:
validated = model.model_validate(raw)
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
return view(self, validated, *args, **kwargs)
return wrapper
return decorator

View File

@ -45,15 +45,6 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
.limit(1)
)
else:
# Try id first (preserves the original "explicit end-user
# id → that specific user" semantics for callers that pass
# a known EndUser.id). Fall back to session_id so daemon-
# supplied session UUIDs dedup against the row created on
# the first Reverse Invocation call — without this, an
# id-only lookup never matched (create writes user_id to
# session_id, id is auto-generated) and a fresh EndUser
# was created per call, breaking multi-turn chat
# continuation (see #36736).
user_model = session.scalar(
select(EndUser)
.where(
@ -62,15 +53,6 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
)
.limit(1)
)
if user_model is None:
user_model = session.scalar(
select(EndUser)
.where(
EndUser.session_id == user_id,
EndUser.tenant_id == tenant_id,
)
.limit(1)
)
if not user_model:
user_model = EndUser(

View File

@ -7,7 +7,7 @@ from hmac import new as hmac_new
from flask import abort, request
from configs import dify_config
from core.db.session_factory import session_factory
from extensions.ext_database import db
from models.model import EndUser
@ -44,8 +44,6 @@ def enterprise_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P, R]:
"""Inject an EndUser for valid inner API HMAC auth, otherwise pass the request through unchanged."""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
if not dify_config.INNER_API:
@ -74,9 +72,9 @@ def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P,
if signature_base64 != token:
return view(*args, **kwargs)
with session_factory.create_session() as session:
kwargs["user"] = session.get(EndUser, user_id)
return view(*args, **kwargs)
kwargs["user"] = db.session.get(EndUser, user_id)
return view(*args, **kwargs)
return decorated

View File

@ -147,7 +147,7 @@ class AppDescribeApi(AppReadResource):
class AppListApi(Resource):
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
@auth_router.guard_workspace(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def get(self, *, auth_data: AuthData):
try:
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))

View File

@ -1,13 +1,11 @@
from __future__ import annotations
from controllers.openapi.auth.conditions import (
EDITION_CE,
EDITION_EE,
HAS_ALLOWED_ROLES,
LOADED_APP_IS_PRIVATE,
PATH_HAS_APP_ID,
WEBAPP_AUTH_ENABLED,
WORKSPACE_MEMBERSHIP_REQUIRED,
WORKSPACE_SCOPED,
)
from controllers.openapi.auth.data import Edition
from controllers.openapi.auth.flow import When
@ -17,18 +15,14 @@ from controllers.openapi.auth.prepare import (
load_app,
load_app_access_mode,
load_tenant,
load_tenant_from_request,
load_workspace_role,
resolve_external_user,
)
from controllers.openapi.auth.verify import (
check_acl,
check_app_api_enabled,
check_app_access,
check_membership,
check_private_app_permission,
check_scope,
check_workspace_member,
check_workspace_mismatch,
check_workspace_role,
)
from libs.oauth_bearer import TokenType
@ -36,17 +30,13 @@ account_pipeline = AuthPipeline(
prepare=[
When(PATH_HAS_APP_ID, then=load_app),
When(PATH_HAS_APP_ID, then=load_tenant),
When(WORKSPACE_MEMBERSHIP_REQUIRED, then=load_tenant_from_request),
load_account,
When(WORKSPACE_SCOPED, then=load_workspace_role),
load_account, # all tokens here are account tokens
When(PATH_HAS_APP_ID & EDITION_EE, then=load_app_access_mode),
],
auth=[
When(PATH_HAS_APP_ID, then=check_app_api_enabled),
check_scope,
When(WORKSPACE_SCOPED, then=check_workspace_member),
When(PATH_HAS_APP_ID, then=check_workspace_mismatch),
When(HAS_ALLOWED_ROLES, then=check_workspace_role),
When(EDITION_CE & PATH_HAS_APP_ID, then=check_membership),
When(EDITION_EE & PATH_HAS_APP_ID & ~WEBAPP_AUTH_ENABLED, then=check_app_access),
When(PATH_HAS_APP_ID & EDITION_EE & WEBAPP_AUTH_ENABLED, then=check_acl),
When(EDITION_EE & LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
],
@ -60,7 +50,6 @@ external_sso_pipeline = AuthPipeline(
When(PATH_HAS_APP_ID, then=load_app_access_mode),
],
auth=[
When(PATH_HAS_APP_ID, then=check_app_api_enabled),
check_scope,
When(PATH_HAS_APP_ID & WEBAPP_AUTH_ENABLED, then=check_acl),
When(LOADED_APP_IS_PRIVATE, then=check_private_app_permission),

View File

@ -50,11 +50,4 @@ EDITION_SAAS = config_cond(lambda: current_edition() == Edition.SAAS)
WEBAPP_AUTH_ENABLED = config_cond(lambda: FeatureService.get_system_features().webapp_auth.enabled)
WORKSPACE_MEMBERSHIP_REQUIRED = request_cond(lambda ctx: ctx.workspace_membership)
HAS_ALLOWED_ROLES = request_cond(lambda ctx: ctx.allowed_roles is not None)
# Caller must belong to the resolved tenant: either an app-scoped path (tenant
# from the app) or an explicit workspace-membership path (tenant from request).
WORKSPACE_SCOPED = PATH_HAS_APP_ID | WORKSPACE_MEMBERSHIP_REQUIRED
LOADED_APP_IS_PRIVATE = data_cond(lambda data: data.app_access_mode == WebAppAccessMode.PRIVATE)

View File

@ -9,7 +9,7 @@ from werkzeug.exceptions import InternalServerError
from configs import dify_config
from libs.oauth_bearer import Scope, TokenType
from models.account import Account, Tenant, TenantAccountRole
from models.account import Account, Tenant
from models.model import App, EndUser
from services.enterprise.enterprise_service import WebAppAccessMode
@ -41,8 +41,6 @@ class RequestContext(BaseModel):
token_type: TokenType
scope: Scope | None = None
path_params: dict[str, str]
workspace_membership: bool = False
allowed_roles: frozenset[TenantAccountRole] | None = None
class AuthData(BaseModel):
@ -58,14 +56,10 @@ class AuthData(BaseModel):
external_identity: ExternalIdentity | None = None
path_params: dict[str, str] = Field(default_factory=dict)
allowed_roles: frozenset[TenantAccountRole] | None = None
app: App | None = None
tenant: Tenant | None = None
app_access_mode: WebAppAccessMode | None = None
tenant_role: TenantAccountRole | None = None
caller: Account | EndUser | None = None
caller_kind: Literal["account", "end_user"] | None = None

View File

@ -34,7 +34,6 @@ from libs.oauth_bearer import (
reset_auth_ctx,
set_auth_ctx,
)
from models.account import TenantAccountRole
from services.feature_service import FeatureService, LicenseStatus
@ -57,15 +56,11 @@ class AuthPipeline:
view: Callable,
*,
scope: Scope | None,
workspace_membership: bool = False,
allowed_roles: frozenset[TenantAccountRole] | None = None,
) -> Any:
req_ctx = RequestContext(
token_type=identity.token_type,
scope=scope,
path_params=dict(request.view_args or {}),
workspace_membership=workspace_membership,
allowed_roles=allowed_roles,
)
data = AuthData(
@ -76,7 +71,6 @@ class AuthPipeline:
scopes=frozenset(identity.scopes),
tenants=dict(identity.verified_tenants),
required_scope=scope,
allowed_roles=allowed_roles,
path_params=dict(req_ctx.path_params),
external_identity=(
ExternalIdentity(email=identity.subject_email, issuer=identity.subject_issuer)
@ -127,41 +121,6 @@ class PipelineRouter:
scope: Scope | None = None,
allowed_token_types: frozenset[TokenType] | None = None,
edition: frozenset[Edition] | None = None,
workspace_membership: bool = False,
allowed_roles: frozenset[TenantAccountRole] | None = None,
) -> Callable:
return self._make_decorator(
scope=scope,
allowed_token_types=allowed_token_types,
edition=edition,
workspace_membership=workspace_membership,
allowed_roles=allowed_roles,
)
def guard_workspace(
self,
*,
scope: Scope | None = None,
allowed_token_types: frozenset[TokenType] | None = None,
edition: frozenset[Edition] | None = None,
allowed_roles: frozenset[TenantAccountRole] | None = None,
) -> Callable:
return self._make_decorator(
scope=scope,
allowed_token_types=allowed_token_types,
edition=edition,
workspace_membership=True,
allowed_roles=allowed_roles,
)
def _make_decorator(
self,
*,
scope: Scope | None,
allowed_token_types: frozenset[TokenType] | None,
edition: frozenset[Edition] | None,
workspace_membership: bool,
allowed_roles: frozenset[TenantAccountRole] | None,
) -> Callable:
def decorator(view: Callable) -> Callable:
@wraps(view)
@ -173,8 +132,6 @@ class PipelineRouter:
scope=scope,
allowed_token_types=allowed_token_types,
edition=edition,
workspace_membership=workspace_membership,
allowed_roles=allowed_roles,
)
return decorated
@ -190,8 +147,6 @@ class PipelineRouter:
scope: Scope | None,
allowed_token_types: frozenset[TokenType] | None,
edition: frozenset[Edition] | None,
workspace_membership: bool = False,
allowed_roles: frozenset[TenantAccountRole] | None = None,
) -> Any:
# 404 not 403 — this edition doesn't expose the feature at all
if edition is not None and current_edition() not in edition:
@ -227,15 +182,7 @@ class PipelineRouter:
if not license_checked and Edition.EE in route.required_edition:
_check_license()
return route.pipeline._run(
identity,
args,
kwargs,
view,
scope=scope,
workspace_membership=workspace_membership,
allowed_roles=allowed_roles,
)
return route.pipeline._run(identity, args, kwargs, view, scope=scope)
def _should_run(step: Any, req_ctx: RequestContext, data: AuthData | None) -> bool:

View File

@ -1,8 +1,5 @@
from __future__ import annotations
import uuid
from flask import request
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound, Unauthorized
from controllers.openapi.auth.data import AuthData
@ -16,18 +13,16 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppAcce
def load_app(data: AuthData) -> None:
if data.app is not None:
return
app_id = data.path_params["app_id"]
app = AppService.get_app_by_id(db.session, app_id)
if not app or app.status != "normal":
raise NotFound("app not found")
if not app.enable_api:
raise Forbidden("service_api_disabled")
data.app = app
def load_tenant(data: AuthData) -> None:
if data.tenant is not None:
return
if data.app is None:
raise InternalServerError("pipeline_invariant_violated: app not loaded before load_tenant")
tenant = TenantService.get_tenant_by_id(db.session, str(data.app.tenant_id))
@ -36,25 +31,7 @@ def load_tenant(data: AuthData) -> None:
data.tenant = tenant
def load_tenant_from_request(data: AuthData) -> None:
if data.tenant is not None:
return
workspace_id = data.path_params.get("workspace_id") or request.args.get("workspace_id")
if not workspace_id:
raise NotFound("workspace not found")
try:
uuid.UUID(workspace_id)
except ValueError:
raise NotFound("workspace not found")
tenant = TenantService.get_tenant_by_id(db.session, workspace_id)
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
raise NotFound("workspace not found")
data.tenant = tenant
def load_account(data: AuthData) -> None:
if data.caller is not None:
return
account = AccountService.get_account_by_id(db.session, str(data.account_id))
if account is None:
raise Unauthorized("account not found")
@ -64,19 +41,6 @@ def load_account(data: AuthData) -> None:
data.caller_kind = "account"
def load_workspace_role(data: AuthData) -> None:
if data.tenant_role is not None:
return
if data.tenant is None or data.account_id is None:
return
if data.caller is not None and getattr(data.caller, "status", None) != "active":
return
role = TenantService.get_account_role_in_tenant(db.session, str(data.account_id), str(data.tenant.id))
if role is None:
return
data.tenant_role = role
def resolve_external_user(data: AuthData) -> None:
if data.tenant is None or data.app is None or data.external_identity is None:
raise Unauthorized("missing context for external user resolution")

View File

@ -0,0 +1,77 @@
"""Workspace role gate.
Layered on top of `validate_bearer` + `accept_subjects(SubjectType.ACCOUNT)`
for routes whose access depends on the caller's `TenantAccountJoin.role`
in the workspace named by the `workspace_id` path parameter.
Usage::
@openapi_ns.route("/workspaces/<string:workspace_id>/members")
class Members(Resource):
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
@require_workspace_role() # any member
def get(self, workspace_id: str): ...
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
def post(self, workspace_id: str): ...
Non-member callers get 404 (matching `GET /openapi/v1/workspaces/<id>`)
so workspace IDs do not leak across tenants. A member without one of the
allowed roles gets 403.
"""
from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import TypeVar
from werkzeug.exceptions import Forbidden, NotFound
from extensions.ext_database import db
from libs.oauth_bearer import try_get_auth_ctx
from models.account import TenantAccountRole
from services.account_service import TenantService
F = TypeVar("F", bound=Callable[..., object])
def require_workspace_role(*allowed_roles: TenantAccountRole) -> Callable[[F], F]:
"""Gate a route on the caller's role in ``workspace_id``.
Pass no roles to require only membership. Pass one or more roles to
require the caller's role be in that set.
"""
allowed = frozenset(allowed_roles)
def deco(fn: F) -> F:
@wraps(fn)
def wrapper(*args: object, **kwargs: object) -> object:
ctx = try_get_auth_ctx()
if ctx is None or ctx.account_id is None:
raise RuntimeError(
"require_workspace_role called without account-bearer context; "
"stack validate_bearer + accept_subjects(SubjectType.ACCOUNT) above it"
)
workspace_id = kwargs.get("workspace_id")
if not workspace_id:
raise RuntimeError("require_workspace_role expects a 'workspace_id' route parameter")
role = TenantService.get_account_role_in_tenant(db.session, str(ctx.account_id), str(workspace_id))
if role is None:
raise NotFound("workspace not found")
if allowed and role not in allowed:
raise Forbidden("insufficient workspace role")
return fn(*args, **kwargs)
return wrapper # type: ignore[return-value]
return deco

View File

@ -1,11 +1,10 @@
from __future__ import annotations
from flask import request
from werkzeug.exceptions import Forbidden, NotFound, UnprocessableEntity
from werkzeug.exceptions import Forbidden, Unauthorized
from controllers.openapi.auth.data import AuthData
from extensions.ext_database import db
from libs.oauth_bearer import Scope, TokenType
from libs.oauth_bearer import Scope, TokenType, check_workspace_membership
from services.account_service import AccountService, TenantService
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode
@ -18,39 +17,17 @@ def check_scope(data: AuthData) -> None:
raise Forbidden("insufficient_scope")
def check_workspace_member(data: AuthData) -> None:
"""Assert the caller belongs to the resolved tenant.
`load_workspace_role` stashes the membership role (None when the caller is
not a member or is inactive). A missing membership surfaces as 404, not
403, so workspace IDs don't leak across tenants.
"""
if data.tenant_role is None:
raise NotFound("workspace not found")
def check_workspace_mismatch(data: AuthData) -> None:
def check_membership(data: AuthData) -> None:
if data.tenant is None:
return
request_workspace_id = data.path_params.get("workspace_id") or request.args.get("workspace_id")
if request_workspace_id and request_workspace_id != str(data.tenant.id):
raise UnprocessableEntity("workspace_id does not match app's workspace")
def check_workspace_role(data: AuthData) -> None:
if data.allowed_roles is None:
return
if data.tenant_role is None:
raise NotFound("workspace not found")
if data.tenant_role not in data.allowed_roles:
raise Forbidden("insufficient workspace role")
def check_app_api_enabled(data: AuthData) -> None:
if data.app is None:
return
if not data.app.enable_api:
raise Forbidden("service_api_disabled")
raise Unauthorized("tenant unset")
if data.account_id is None:
raise Unauthorized("account_id unset")
check_workspace_membership(
account_id=data.account_id,
tenant_id=data.tenant.id,
token_hash=data.token_hash,
membership_cache=data.tenants,
)
def check_app_access(data: AuthData) -> None:

View File

@ -26,12 +26,7 @@ from werkzeug.exceptions import BadRequest
from configs import dify_config
from controllers.common.schema import query_params_from_model
from controllers.console.wraps import (
account_initialization_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.openapi import openapi_ns
from controllers.openapi._models import (
AccountPayload,
@ -47,6 +42,7 @@ from controllers.openapi._models import (
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType, bearer_feature_required
from libs.rate_limit import (
LIMIT_APPROVE_CONSOLE,
@ -54,7 +50,6 @@ from libs.rate_limit import (
LIMIT_LOOKUP_PUBLIC,
rate_limit,
)
from models import Account
from services.account_service import TenantService
from services.oauth_device_flow import (
ACCOUNT_ISSUER_SENTINEL,
@ -211,12 +206,11 @@ class DeviceApproveApi(Resource):
@account_initialization_required
@bearer_feature_required
@rate_limit(LIMIT_APPROVE_CONSOLE)
@with_current_user
@with_current_tenant_id
def post(self, tenant: str, account: Account):
def post(self):
payload = _validate_json(DeviceMutateRequest)
user_code = payload.user_code.strip().upper()
account, tenant = current_account_with_tenant()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)

View File

@ -5,8 +5,9 @@ endpoints. Account bearers (dfoa_) see every tenant they're a member of.
External SSO bearers (dfoe_) have no account_id and so see an empty list —
that matches /openapi/v1/account.
Member-management endpoints use ``guard_workspace`` which enforces
workspace membership and optional role requirements via the auth pipeline.
Member-management endpoints are gated by both `accept_subjects` (SSO out)
and `require_workspace_role` (membership / role lookup against the path's
``workspace_id``).
"""
from __future__ import annotations
@ -36,6 +37,7 @@ from controllers.openapi._models import (
)
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from controllers.openapi.auth.role_gate import require_workspace_role
from extensions.ext_database import db
from libs.oauth_bearer import Scope, TokenType
from models import Account, Tenant, TenantAccountJoin
@ -150,7 +152,8 @@ class WorkspaceSwitchApi(Resource):
"""
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
@auth_router.guard_workspace(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@require_workspace_role()
def post(self, workspace_id: str, *, auth_data: AuthData):
account = _load_account(auth_data.account_id)
@ -176,7 +179,8 @@ class WorkspaceMembersApi(Resource):
@openapi_ns.doc(params=query_params_from_model(MemberListQuery))
@openapi_ns.response(200, "Member list", openapi_ns.models[MemberListResponse.__name__])
@auth_router.guard_workspace(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@require_workspace_role()
def get(self, workspace_id: str, *, auth_data: AuthData):
try:
query = MemberListQuery.model_validate(request.args.to_dict(flat=True))
@ -198,11 +202,8 @@ class WorkspaceMembersApi(Resource):
@openapi_ns.expect(openapi_ns.models[MemberInvitePayload.__name__])
@openapi_ns.response(201, "Member invited", openapi_ns.models[MemberInviteResponse.__name__])
@auth_router.guard_workspace(
scope=Scope.WORKSPACE_WRITE,
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
)
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
def post(self, workspace_id: str, *, auth_data: AuthData):
payload = _validate_body(MemberInvitePayload)
inviter = _load_account(auth_data.account_id)
@ -252,11 +253,8 @@ class WorkspaceMemberApi(Resource):
"""
@openapi_ns.response(200, "Member removed", openapi_ns.models[MemberActionResponse.__name__])
@auth_router.guard_workspace(
scope=Scope.WORKSPACE_WRITE,
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
)
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
def delete(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
operator = _load_account(auth_data.account_id)
tenant = _load_tenant(workspace_id)
@ -286,11 +284,8 @@ class WorkspaceMemberRoleApi(Resource):
@openapi_ns.expect(openapi_ns.models[MemberRoleUpdatePayload.__name__])
@openapi_ns.response(200, "Role updated", openapi_ns.models[MemberActionResponse.__name__])
@auth_router.guard_workspace(
scope=Scope.WORKSPACE_WRITE,
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
)
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
def put(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
payload = _validate_body(MemberRoleUpdatePayload)
operator = _load_account(auth_data.account_id)

View File

@ -6,7 +6,7 @@ from flask_restx import Resource
from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field, TypeAdapter
from controllers.common.schema import query_params_from_model, register_schema_models
from controllers.common.schema import register_schema_models
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
@ -32,19 +32,8 @@ class AnnotationReplyActionPayload(BaseModel):
embedding_model_name: str = Field(description="Embedding model name")
class AnnotationListQuery(BaseModel):
page: int = Field(default=1, ge=1, description="Page number")
limit: int = Field(default=20, ge=1, description="Number of annotations per page")
keyword: str = Field(default="", description="Keyword to search annotations")
register_schema_models(
service_api_ns,
AnnotationCreatePayload,
AnnotationReplyActionPayload,
AnnotationListQuery,
Annotation,
AnnotationList,
service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload, Annotation, AnnotationList
)
@ -111,7 +100,6 @@ class AnnotationReplyActionStatusApi(Resource):
class AnnotationListApi(Resource):
@service_api_ns.doc("list_annotations")
@service_api_ns.doc(description="List annotations for the application")
@service_api_ns.doc(params=query_params_from_model(AnnotationListQuery))
@service_api_ns.doc(
responses={
200: "Annotations retrieved successfully",
@ -126,18 +114,18 @@ class AnnotationListApi(Resource):
@validate_app_token
def get(self, app_model: App):
"""List annotations for the application."""
query = AnnotationListQuery.model_validate(request.args.to_dict(flat=True))
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default="", type=str)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
app_model.id, query.page, query.limit, query.keyword
)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword)
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
response = AnnotationList(
data=annotation_models,
has_more=len(annotation_list) == query.limit,
limit=query.limit,
has_more=len(annotation_list) == limit,
limit=limit,
total=total,
page=query.page,
page=page,
)
return response.model_dump(mode="json")

View File

@ -41,15 +41,6 @@ from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
def _resolve_agent_app_streaming(*, app_mode: AppMode, response_mode: str | None) -> bool:
"""Agent App runtime is SSE-only until backend blocking runs are supported."""
if app_mode != AppMode.AGENT:
return response_mode == "streaming"
if response_mode == "blocking":
raise BadRequest("Agent App only supports streaming response mode.")
return True
class CompletionRequestPayload(BaseModel):
inputs: dict[str, Any]
query: str = Field(default="")
@ -206,7 +197,7 @@ class ChatApi(Resource):
Supports conversation management and both blocking and streaming response modes.
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
payload = ChatRequestPayload.model_validate(service_api_ns.payload or {})
@ -216,7 +207,7 @@ class ChatApi(Resource):
if external_trace_id:
args["external_trace_id"] = external_trace_id
streaming = _resolve_agent_app_streaming(app_mode=app_mode, response_mode=payload.response_mode)
streaming = payload.response_mode == "streaming"
try:
response = AppGenerateService.generate(
@ -271,7 +262,7 @@ class ChatStopApi(Resource):
def post(self, app_model: App, end_user: EndUser, task_id: str):
"""Stop a running chat message generation."""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
AppTaskService.stop_task(

View File

@ -155,7 +155,7 @@ class ConversationApi(Resource):
Supports pagination using last_id and limit parameters.
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
query_args = ConversationListQuery.model_validate(request.args.to_dict())
@ -199,7 +199,7 @@ class ConversationDetailApi(Resource):
def delete(self, app_model: App, end_user: EndUser, c_id: UUID):
"""Delete a specific conversation."""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -228,7 +228,7 @@ class ConversationRenameApi(Resource):
def post(self, app_model: App, end_user: EndUser, c_id: UUID):
"""Rename a conversation or auto-generate a name."""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)

View File

@ -56,7 +56,7 @@ class MessageListApi(Resource):
Retrieves messages with pagination support using first_id.
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
query_args = MessageListQuery.model_validate(request.args.to_dict())
@ -167,7 +167,7 @@ class MessageSuggestedApi(Resource):
"""
message_id_str = str(message_id)
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
try:

View File

@ -12,6 +12,7 @@ from typing import Self
from uuid import UUID
from flask import request, send_file
from flask_restx import marshal
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import desc, func, select
from werkzeug.exceptions import Forbidden, NotFound
@ -26,12 +27,7 @@ from controllers.common.errors import (
UnsupportedFileTypeError,
)
from controllers.common.fields import UrlResponse
from controllers.common.schema import (
query_params_from_model,
register_enum_models,
register_response_schema_models,
register_schema_models,
)
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset.error import (
@ -48,13 +44,7 @@ from core.errors.error import ProviderTokenNotInitError
from core.rag.entities import PreProcessingRule, Rule, Segmentation
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.document_fields import (
DocumentListResponse,
DocumentResponse,
DocumentStatusListResponse,
)
from libs.helper import dump_response
from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import SegmentStatus
@ -117,44 +107,6 @@ class DocumentListQuery(BaseModel):
status: str | None = Field(default=None, description="Document status filter")
DOCUMENT_CREATE_BY_FILE_PARAMS = {
"dataset_id": "Dataset ID",
"file": {
"in": "formData",
"type": "file",
"required": True,
"description": "Document file to upload.",
},
"data": {
"in": "formData",
"type": "string",
"required": False,
"description": "Optional JSON string with document creation settings.",
},
}
DOCUMENT_UPDATE_BY_FILE_PARAMS = {
"dataset_id": "Dataset ID",
"document_id": "Document ID",
"file": {
"in": "formData",
"type": "file",
"required": False,
"description": "Replacement document file.",
},
"data": {
"in": "formData",
"type": "string",
"required": False,
"description": "Optional JSON string with document update settings.",
},
}
class DocumentAndBatchResponse(ResponseModel):
document: DocumentResponse
batch: str
register_enum_models(service_api_ns, RetrievalMethod)
register_schema_models(
@ -169,14 +121,7 @@ register_schema_models(
PreProcessingRule,
Segmentation,
)
register_response_schema_models(
service_api_ns,
UrlResponse,
DocumentResponse,
DocumentAndBatchResponse,
DocumentListResponse,
DocumentStatusListResponse,
)
register_response_schema_models(service_api_ns, UrlResponse)
def _create_document_by_text(tenant_id: str, dataset_id: UUID) -> tuple[Mapping[str, object], int]:
@ -243,7 +188,8 @@ def _create_document_by_text(tenant_id: str, dataset_id: UUID) -> tuple[Mapping[
raise ProviderNotInitializeError(ex.description)
document = documents[0]
return dump_response(DocumentAndBatchResponse, {"document": document, "batch": batch}), 200
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
return documents_and_batch_fields, 200
def _update_document_by_text(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
@ -302,7 +248,8 @@ def _update_document_by_text(tenant_id: str, dataset_id: UUID, document_id: UUID
raise ProviderNotInitializeError(ex.description)
document = documents[0]
return dump_response(DocumentAndBatchResponse, {"document": document, "batch": batch}), 200
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
return documents_and_batch_fields, 200
@service_api_ns.route("/datasets/<uuid:dataset_id>/document/create-by-text")
@ -320,9 +267,6 @@ class DocumentAddByTextApi(DatasetApiResource):
400: "Bad request - invalid parameters",
}
)
@service_api_ns.response(
200, "Document created successfully", service_api_ns.models[DocumentAndBatchResponse.__name__]
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_resource_check("documents", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
@ -352,9 +296,6 @@ class DeprecatedDocumentAddByTextApi(DatasetApiResource):
400: "Bad request - invalid parameters",
}
)
@service_api_ns.response(
200, "Document created successfully", service_api_ns.models[DocumentAndBatchResponse.__name__]
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_resource_check("documents", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
@ -378,9 +319,6 @@ class DocumentUpdateByTextApi(DatasetApiResource):
404: "Document not found",
}
)
@service_api_ns.response(
200, "Document updated successfully", service_api_ns.models[DocumentAndBatchResponse.__name__]
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
@ -409,9 +347,6 @@ class DeprecatedDocumentUpdateByTextApi(DatasetApiResource):
404: "Document not found",
}
)
@service_api_ns.response(
200, "Document updated successfully", service_api_ns.models[DocumentAndBatchResponse.__name__]
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
@ -428,7 +363,7 @@ class DocumentAddByFileApi(DatasetApiResource):
@service_api_ns.doc("create_document_by_file")
@service_api_ns.doc(description="Create a new document by uploading a file")
@service_api_ns.doc(consumes=["multipart/form-data"], params=DOCUMENT_CREATE_BY_FILE_PARAMS)
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@service_api_ns.doc(
responses={
200: "Document created successfully",
@ -436,9 +371,6 @@ class DocumentAddByFileApi(DatasetApiResource):
400: "Bad request - invalid file or parameters",
}
)
@service_api_ns.response(
200, "Document created successfully", service_api_ns.models[DocumentAndBatchResponse.__name__]
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_resource_check("documents", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
@ -530,7 +462,8 @@ class DocumentAddByFileApi(DatasetApiResource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
return dump_response(DocumentAndBatchResponse, {"document": document, "batch": batch}), 200
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
return documents_and_batch_fields, 200
def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
@ -606,7 +539,8 @@ def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
return dump_response(DocumentAndBatchResponse, {"document": document, "batch": document.batch}), 200
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
return documents_and_batch_fields, 200
@service_api_ns.route(
@ -624,7 +558,7 @@ class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
"Use PATCH /datasets/{dataset_id}/documents/{document_id} instead."
)
)
@service_api_ns.doc(consumes=["multipart/form-data"], params=DOCUMENT_UPDATE_BY_FILE_PARAMS)
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@service_api_ns.doc(
responses={
200: "Document updated successfully",
@ -632,9 +566,6 @@ class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
404: "Document not found",
}
)
@service_api_ns.response(
200, "Document updated successfully", service_api_ns.models[DocumentAndBatchResponse.__name__]
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
@ -646,7 +577,7 @@ class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
class DocumentListApi(DatasetApiResource):
@service_api_ns.doc("list_documents")
@service_api_ns.doc(description="List all documents in a dataset")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", **query_params_from_model(DocumentListQuery)})
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@service_api_ns.doc(
responses={
200: "Documents retrieved successfully",
@ -654,9 +585,6 @@ class DocumentListApi(DatasetApiResource):
404: "Dataset not found",
}
)
@service_api_ns.response(
200, "Documents retrieved successfully", service_api_ns.models[DocumentListResponse.__name__]
)
def get(self, tenant_id, dataset_id: UUID):
dataset_id_str = str(dataset_id)
tenant_id = str(tenant_id)
@ -690,14 +618,14 @@ class DocumentListApi(DatasetApiResource):
)
response = {
"data": documents,
"data": marshal(documents, document_fields),
"has_more": len(documents) == query_params.limit,
"limit": query_params.limit,
"total": paginated_documents.total,
"page": query_params.page,
}
return dump_response(DocumentListResponse, response)
return response
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/download-zip")
@ -752,11 +680,6 @@ class DocumentIndexingStatusApi(DatasetApiResource):
404: "Dataset or documents not found",
}
)
@service_api_ns.response(
200,
"Indexing status retrieved successfully",
service_api_ns.models[DocumentStatusListResponse.__name__],
)
def get(self, tenant_id, dataset_id: UUID, batch: str):
dataset_id_str = str(dataset_id)
tenant_id = str(tenant_id)
@ -806,8 +729,9 @@ class DocumentIndexingStatusApi(DatasetApiResource):
"completed_segments": completed_segments,
"total_segments": total_segments,
}
documents_status.append(document_dict)
return dump_response(DocumentStatusListResponse, {"data": documents_status})
documents_status.append(marshal(document_dict, document_status_fields))
data = {"data": documents_status}
return data
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/download")
@ -966,7 +890,7 @@ class DocumentApi(DatasetApiResource):
@service_api_ns.doc("update_document_by_file")
@service_api_ns.doc(description="Update an existing document by uploading a file")
@service_api_ns.doc(consumes=["multipart/form-data"], params=DOCUMENT_UPDATE_BY_FILE_PARAMS)
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@service_api_ns.doc(
responses={
200: "Document updated successfully",
@ -974,9 +898,6 @@ class DocumentApi(DatasetApiResource):
404: "Document not found",
}
)
@service_api_ns.response(
200, "Document updated successfully", service_api_ns.models[DocumentAndBatchResponse.__name__]
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID):

View File

@ -1,18 +1,15 @@
from typing import cast
from typing import Any
from uuid import UUID
from pydantic import BaseModel, Field, ValidationError, field_validator
from flask import request
from flask_restx import marshal
from pydantic import BaseModel, Field
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from configs import dify_config
from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
from controllers.common.schema import (
query_params_from_model,
query_params_from_request,
register_response_schema_models,
register_schema_models,
)
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import (
@ -25,19 +22,10 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.segment_fields import (
ChildChunkDetailResponse,
ChildChunkListResponse,
SegmentDetailResponse,
SegmentResponse,
segment_response_with_summary,
segment_responses_with_summaries,
)
from fields.segment_fields import child_chunk_fields, segment_fields
from graphon.model_runtime.entities.model_entities import ModelType
from libs.helper import dump_response
from libs.login import current_account_with_tenant
from models.dataset import Dataset, DocumentSegment
from models.dataset import Dataset
from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
@ -46,27 +34,35 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from services.summary_index_service import SummaryIndexService
class SegmentCreateItemPayload(BaseModel):
content: str = Field(min_length=1)
answer: str | None = None
keywords: list[str] | None = None
attachment_ids: list[str] | None = None
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict[str, Any]:
"""Marshal a single segment and enrich it with summary content."""
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict
@field_validator("content")
@classmethod
def validate_content(cls, value: str) -> str:
if not value.strip():
raise ValueError("Content is empty")
return value
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict[str, Any]]:
"""Marshal multiple segments and enrich them with summary content (batch query)."""
segment_ids = [segment.id for segment in segments]
summaries: dict[str, str | None] = {}
if segment_ids:
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
result: list[dict[str, Any]] = []
for segment in segments:
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict["summary"] = summaries.get(segment.id)
result.append(segment_dict)
return result
class SegmentCreatePayload(BaseModel):
segments: list[SegmentCreateItemPayload] = Field(min_length=1)
segments: list[dict[str, Any]] | None = None
class SegmentListQuery(BaseModel):
limit: int = Field(default=20, ge=1)
page: int = Field(default=1, ge=1)
status: list[str] = Field(default_factory=list)
keyword: str | None = None
@ -81,31 +77,9 @@ class ChildChunkListQuery(BaseModel):
page: int = Field(default=1, ge=1)
class SegmentDocParams:
DATASET_DOCUMENT = {"dataset_id": "Dataset ID", "document_id": "Document ID"}
DATASET_DOCUMENT_SEGMENT = {**DATASET_DOCUMENT, "segment_id": "Segment ID"}
DATASET_DOCUMENT_PARENT_SEGMENT = {**DATASET_DOCUMENT, "segment_id": "Parent segment ID"}
DATASET_DOCUMENT_CHILD_CHUNK = {**DATASET_DOCUMENT_PARENT_SEGMENT, "child_chunk_id": "Child chunk ID"}
class SegmentCreateListResponse(ResponseModel):
data: list[SegmentResponse]
doc_form: str
class SegmentListResponse(ResponseModel):
data: list[SegmentResponse]
doc_form: str
total: int
has_more: bool
limit: int
page: int
register_schema_models(
service_api_ns,
SegmentCreatePayload,
SegmentCreateItemPayload,
SegmentListQuery,
SegmentUpdateArgs,
SegmentUpdatePayload,
@ -113,15 +87,6 @@ register_schema_models(
ChildChunkListQuery,
ChildChunkUpdatePayload,
)
register_response_schema_models(
service_api_ns,
SegmentResponse,
SegmentCreateListResponse,
SegmentListResponse,
SegmentDetailResponse,
ChildChunkDetailResponse,
ChildChunkListResponse,
)
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
@ -131,7 +96,7 @@ class SegmentApi(DatasetApiResource):
@service_api_ns.expect(service_api_ns.models[SegmentCreatePayload.__name__])
@service_api_ns.doc("create_segments")
@service_api_ns.doc(description="Create segments in a document")
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@service_api_ns.doc(
responses={
200: "Segments created successfully",
@ -140,11 +105,6 @@ class SegmentApi(DatasetApiResource):
404: "Dataset or document not found",
}
)
@service_api_ns.response(
200,
"Segments created successfully",
service_api_ns.models[SegmentCreateListResponse.__name__],
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
@ -184,35 +144,26 @@ class SegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# validate args
try:
payload = SegmentCreatePayload.model_validate(service_api_ns.payload or {})
except ValidationError as e:
return {"error": str(e)}, 400
segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
if segments_limit > 0 and len(payload.segments) > segments_limit:
raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")
segment_items = [segment.model_dump(exclude_none=True) for segment in payload.segments]
payload = SegmentCreatePayload.model_validate(service_api_ns.payload or {})
if payload.segments is not None:
segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
if segments_limit > 0 and len(payload.segments) > segments_limit:
raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")
for args_item in segment_items:
SegmentService.segment_create_args_validate(args_item, document)
segments = cast(list[DocumentSegment], SegmentService.multi_create_segment(segment_items, document, dataset))
segment_ids = [segment.id for segment in segments]
summaries: dict[str, str | None] = {}
if segment_ids:
summary_records = SummaryIndexService.get_segments_summaries(
segment_ids=segment_ids, dataset_id=dataset_id_str
)
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
response = {
"data": segment_responses_with_summaries(segments, summaries),
"doc_form": document.doc_form,
}
return dump_response(SegmentCreateListResponse, response), 200
for args_item in payload.segments:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
return {
"data": _marshal_segments_with_summary(segments, dataset_id_str),
"doc_form": document.doc_form,
}, 200
else:
return {"error": "Segments is required"}, 400
@service_api_ns.expect(service_api_ns.models[SegmentListQuery.__name__])
@service_api_ns.doc("list_segments")
@service_api_ns.doc(description="List segments in a document")
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
@service_api_ns.doc(params=query_params_from_model(SegmentListQuery))
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@service_api_ns.doc(
responses={
200: "Segments retrieved successfully",
@ -220,22 +171,12 @@ class SegmentApi(DatasetApiResource):
404: "Dataset or document not found",
}
)
@service_api_ns.response(
200,
"Segments retrieved successfully",
service_api_ns.models[SegmentListResponse.__name__],
)
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
_, current_tenant_id = current_account_with_tenant()
"""Get segments."""
# check dataset
args = query_params_from_request(
SegmentListQuery,
list_fields=("status",),
use_defaults_for_malformed_ints=True,
)
page = args.page
limit = args.limit
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
dataset_id_str = str(dataset_id)
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
@ -264,6 +205,13 @@ class SegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
args = SegmentListQuery.model_validate(
{
"status": request.args.getlist("status"),
"keyword": request.args.get("keyword"),
}
)
segments, total = SegmentService.get_segments(
document_id=document_id_str,
tenant_id=current_tenant_id,
@ -272,16 +220,9 @@ class SegmentApi(DatasetApiResource):
page=page,
limit=limit,
)
segment_ids = [segment.id for segment in segments]
summaries: dict[str, str | None] = {}
if segment_ids:
summary_records = SummaryIndexService.get_segments_summaries(
segment_ids=segment_ids, dataset_id=dataset_id_str
)
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
response = {
"data": segment_responses_with_summaries(segments, summaries),
"data": _marshal_segments_with_summary(segments, dataset_id_str),
"doc_form": document.doc_form,
"total": total,
"has_more": len(segments) == limit,
@ -289,14 +230,16 @@ class SegmentApi(DatasetApiResource):
"page": page,
}
return dump_response(SegmentListResponse, response), 200
return response, 200
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
class DatasetSegmentApi(DatasetApiResource):
@service_api_ns.doc("delete_segment")
@service_api_ns.doc(description="Delete a specific segment")
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
@service_api_ns.doc(
params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Segment ID to delete"}
)
@service_api_ns.doc(
responses={
204: "Segment deleted successfully",
@ -332,7 +275,9 @@ class DatasetSegmentApi(DatasetApiResource):
@service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__])
@service_api_ns.doc("update_segment")
@service_api_ns.doc(description="Update a specific segment")
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
@service_api_ns.doc(
params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Segment ID to update"}
)
@service_api_ns.doc(
responses={
200: "Segment updated successfully",
@ -340,7 +285,6 @@ class DatasetSegmentApi(DatasetApiResource):
404: "Dataset, document, or segment not found",
}
)
@service_api_ns.response(200, "Segment updated successfully", service_api_ns.models[SegmentDetailResponse.__name__])
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
@ -384,16 +328,13 @@ class DatasetSegmentApi(DatasetApiResource):
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
summary = SummaryIndexService.get_segment_summary(segment_id=updated_segment.id, dataset_id=dataset_id_str)
response = {
"data": segment_response_with_summary(updated_segment, summary.summary_content if summary else None),
return {
"data": _marshal_segment_with_summary(updated_segment, dataset_id_str),
"doc_form": document.doc_form,
}
return dump_response(SegmentDetailResponse, response), 200
}, 200
@service_api_ns.doc("get_segment")
@service_api_ns.doc(description="Get a specific segment by ID")
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
@service_api_ns.doc(
responses={
200: "Segment retrieved successfully",
@ -401,11 +342,6 @@ class DatasetSegmentApi(DatasetApiResource):
404: "Dataset, document, or segment not found",
}
)
@service_api_ns.response(
200,
"Segment retrieved successfully",
service_api_ns.models[SegmentDetailResponse.__name__],
)
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
@ -428,12 +364,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id_str)
response = {
"data": segment_response_with_summary(segment, summary.summary_content if summary else None),
"doc_form": document.doc_form,
}
return dump_response(SegmentDetailResponse, response), 200
return {"data": _marshal_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
@service_api_ns.route(
@ -445,7 +376,9 @@ class ChildChunkApi(DatasetApiResource):
@service_api_ns.expect(service_api_ns.models[ChildChunkCreatePayload.__name__])
@service_api_ns.doc("create_child_chunk")
@service_api_ns.doc(description="Create a new child chunk for a segment")
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
@service_api_ns.doc(
params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Parent segment ID"}
)
@service_api_ns.doc(
responses={
200: "Child chunk created successfully",
@ -453,11 +386,6 @@ class ChildChunkApi(DatasetApiResource):
404: "Dataset, document, or segment not found",
}
)
@service_api_ns.response(
200,
"Child chunk created successfully",
service_api_ns.models[ChildChunkDetailResponse.__name__],
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
@ -509,12 +437,14 @@ class ChildChunkApi(DatasetApiResource):
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return dump_response(ChildChunkDetailResponse, {"data": child_chunk}), 200
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@service_api_ns.expect(service_api_ns.models[ChildChunkListQuery.__name__])
@service_api_ns.doc("list_child_chunks")
@service_api_ns.doc(description="List child chunks for a segment")
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
@service_api_ns.doc(params=query_params_from_model(ChildChunkListQuery))
@service_api_ns.doc(
params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Parent segment ID"}
)
@service_api_ns.doc(
responses={
200: "Child chunks retrieved successfully",
@ -522,11 +452,6 @@ class ChildChunkApi(DatasetApiResource):
404: "Dataset, document, or segment not found",
}
)
@service_api_ns.response(
200,
"Child chunks retrieved successfully",
service_api_ns.models[ChildChunkListResponse.__name__],
)
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
_, current_tenant_id = current_account_with_tenant()
"""Get child chunks."""
@ -550,7 +475,13 @@ class ChildChunkApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
args = query_params_from_request(ChildChunkListQuery, use_defaults_for_malformed_ints=True)
args = ChildChunkListQuery.model_validate(
{
"limit": request.args.get("limit", default=20, type=int),
"keyword": request.args.get("keyword"),
"page": request.args.get("page", default=1, type=int),
}
)
page = args.page
limit = min(args.limit, 100)
@ -560,14 +491,13 @@ class ChildChunkApi(DatasetApiResource):
segment_id_str, document_id_str, dataset_id_str, page, limit, keyword
)
response = {
"data": child_chunks.items,
return {
"data": marshal(child_chunks.items, child_chunk_fields),
"total": child_chunks.total,
"total_pages": child_chunks.pages,
"page": page,
"limit": limit,
}
return dump_response(ChildChunkListResponse, response), 200
}, 200
@service_api_ns.route(
@ -578,7 +508,14 @@ class DatasetChildChunkApi(DatasetApiResource):
@service_api_ns.doc("delete_child_chunk")
@service_api_ns.doc(description="Delete a specific child chunk")
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
@service_api_ns.doc(
params={
"dataset_id": "Dataset ID",
"document_id": "Document ID",
"segment_id": "Parent segment ID",
"child_chunk_id": "Child chunk ID to delete",
}
)
@service_api_ns.doc(
responses={
204: "Child chunk deleted successfully",
@ -612,7 +549,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Segment not found.")
# validate segment belongs to the specified document
if segment.document_id != document_id_str:
if str(segment.document_id) != str(document_id_str):
raise NotFound("Document not found.")
child_chunk_id_str = str(child_chunk_id)
@ -624,7 +561,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Child chunk not found.")
# validate child chunk belongs to the specified segment
if child_chunk.segment_id != segment.id:
if str(child_chunk.segment_id) != str(segment.id):
raise NotFound("Child chunk not found.")
try:
@ -637,7 +574,14 @@ class DatasetChildChunkApi(DatasetApiResource):
@service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__])
@service_api_ns.doc("update_child_chunk")
@service_api_ns.doc(description="Update a specific child chunk")
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
@service_api_ns.doc(
params={
"dataset_id": "Dataset ID",
"document_id": "Document ID",
"segment_id": "Parent segment ID",
"child_chunk_id": "Child chunk ID to update",
}
)
@service_api_ns.doc(
responses={
200: "Child chunk updated successfully",
@ -645,11 +589,6 @@ class DatasetChildChunkApi(DatasetApiResource):
404: "Dataset, document, segment, or child chunk not found",
}
)
@service_api_ns.response(
200,
"Child chunk updated successfully",
service_api_ns.models[ChildChunkDetailResponse.__name__],
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
@ -677,7 +616,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Segment not found.")
# validate segment belongs to the specified document
if segment.document_id != document_id_str:
if str(segment.document_id) != str(document_id_str):
raise NotFound("Segment not found.")
child_chunk_id_str = str(child_chunk_id)
@ -689,7 +628,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Child chunk not found.")
# validate child chunk belongs to the specified segment
if child_chunk.segment_id != segment.id:
if str(child_chunk.segment_id) != str(segment.id):
raise NotFound("Child chunk not found.")
# validate args
@ -700,4 +639,4 @@ class DatasetChildChunkApi(DatasetApiResource):
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return dump_response(ChildChunkDetailResponse, {"data": child_chunk}), 200
return {"data": marshal(child_chunk, child_chunk_fields)}, 200

View File

@ -2,7 +2,7 @@ import logging
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.common.fields import SimpleResultResponse
@ -37,15 +37,6 @@ from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
def _resolve_agent_app_streaming(*, app_mode: AppMode, response_mode: str | None) -> bool:
"""Agent App runtime is SSE-only until backend blocking runs are supported."""
if app_mode != AppMode.AGENT:
return response_mode == "streaming"
if response_mode == "blocking":
raise BadRequest("Agent App only supports streaming response mode.")
return True
class CompletionMessagePayload(BaseModel):
inputs: dict[str, Any] = Field(description="Input variables for the completion")
query: str = Field(default="", description="Query text for completion")
@ -180,13 +171,13 @@ class ChatApi(WebApiResource):
)
def post(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
payload = ChatMessagePayload.model_validate(web_ns.payload or {})
args = payload.model_dump(exclude_none=True)
streaming = _resolve_agent_app_streaming(app_mode=app_mode, response_mode=payload.response_mode)
streaming = payload.response_mode == "streaming"
args["auto_generate_name"] = False
try:
@ -237,7 +228,7 @@ class ChatStopApi(WebApiResource):
@web_ns.response(200, "Success", web_ns.models[SimpleResultResponse.__name__])
def post(self, app_model: App, end_user: EndUser, task_id: str):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
AppTaskService.stop_task(

View File

@ -83,7 +83,7 @@ class ConversationListApi(WebApiResource):
)
def get(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
raw_args = request.args.to_dict()
@ -129,7 +129,7 @@ class ConversationApi(WebApiResource):
)
def delete(self, app_model: App, end_user: EndUser, c_id: UUID):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -168,7 +168,7 @@ class ConversationRenameApi(WebApiResource):
)
def post(self, app_model: App, end_user: EndUser, c_id: UUID):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -206,7 +206,7 @@ class ConversationPinApi(WebApiResource):
@web_ns.response(200, "Conversation pinned successfully", web_ns.models[ResultResponse.__name__])
def patch(self, app_model: App, end_user: EndUser, c_id: UUID):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -237,7 +237,7 @@ class ConversationUnPinApi(WebApiResource):
@web_ns.response(200, "Conversation unpinned successfully", web_ns.models[ResultResponse.__name__])
def patch(self, app_model: App, end_user: EndUser, c_id: UUID):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)

View File

@ -83,7 +83,7 @@ class MessageListApi(WebApiResource):
)
def get(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
raw_args = request.args.to_dict()
@ -225,7 +225,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
)
def get(self, app_model: App, end_user: EndUser, message_id: UUID):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
message_id_str = str(message_id)

View File

@ -9,7 +9,7 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
from core.file import remote_fetcher
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from graphon.file import helpers as file_helpers
@ -60,10 +60,10 @@ class RemoteFileInfoApi(WebApiResource):
HTTPException: If the remote file cannot be accessed
"""
decoded_url = helpers.decode_remote_url(url, request.query_string)
resp = remote_fetcher.make_request("HEAD", decoded_url)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
# failed back to get method
resp = remote_fetcher.make_request("GET", decoded_url, timeout=3)
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
info = RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
@ -112,9 +112,9 @@ class RemoteFileUploadApi(WebApiResource):
url = str(payload.url)
try:
resp = remote_fetcher.make_request("HEAD", url=url)
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = remote_fetcher.make_request("GET", url=url, timeout=3, follow_redirects=True)
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as e:
@ -125,7 +125,7 @@ class RemoteFileUploadApi(WebApiResource):
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
raise FileTooLargeError
content = resp.content if resp.request.method == "GET" else remote_fetcher.make_request("GET", url).content
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
upload_file = FileService(db.engine).upload_file(

View File

@ -397,40 +397,39 @@ class CotAgentRunner(BaseAgentRunner, ABC):
current_scratchpad: AgentScratchpadUnit | None = None
for message in self.history_prompt_messages:
match message:
case AssistantPromptMessage():
if not current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content or "I am thinking about how to help you",
action_str="",
action=None,
observation=None,
if isinstance(message, AssistantPromptMessage):
if not current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content or "I am thinking about how to help you",
action_str="",
action=None,
observation=None,
)
scratchpads.append(current_scratchpad)
if message.tool_calls:
try:
current_scratchpad.action = AgentScratchpadUnit.Action(
action_name=message.tool_calls[0].function.name,
action_input=json.loads(message.tool_calls[0].function.arguments),
)
scratchpads.append(current_scratchpad)
if message.tool_calls:
try:
current_scratchpad.action = AgentScratchpadUnit.Action(
action_name=message.tool_calls[0].function.name,
action_input=json.loads(message.tool_calls[0].function.arguments),
)
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
except Exception:
logger.exception("Failed to parse tool call from assistant message")
case ToolPromptMessage():
if current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad.observation = message.content
else:
raise NotImplementedError("expected str type")
case UserPromptMessage():
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
scratchpads = []
current_scratchpad = None
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
except Exception:
logger.exception("Failed to parse tool call from assistant message")
elif isinstance(message, ToolPromptMessage):
if current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad.observation = message.content
else:
raise NotImplementedError("expected str type")
elif isinstance(message, UserPromptMessage):
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
scratchpads = []
current_scratchpad = None
result.append(message)
result.append(message)
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))

View File

@ -39,17 +39,16 @@ class CotCompletionAgentRunner(CotAgentRunner):
historic_prompt = ""
for message in historic_prompt_messages:
match message:
case UserPromptMessage():
historic_prompt += f"Question: {message.content}\n\n"
case AssistantPromptMessage():
if isinstance(message.content, str):
historic_prompt += message.content + "\n\n"
elif isinstance(message.content, list):
for content in message.content:
if not isinstance(content, TextPromptMessageContent):
continue
historic_prompt += content.data
if isinstance(message, UserPromptMessage):
historic_prompt += f"Question: {message.content}\n\n"
elif isinstance(message, AssistantPromptMessage):
if isinstance(message.content, str):
historic_prompt += message.content + "\n\n"
elif isinstance(message.content, list):
for content in message.content:
if not isinstance(content, TextPromptMessageContent):
continue
historic_prompt += content.data
return historic_prompt

View File

@ -237,9 +237,7 @@ class EasyUIBasedAppConfig(AppConfig):
"""
app_model_config_from: EasyUIBasedAppModelConfigFrom
# Optional: an Agent App has no legacy app_model_config row, so the id may be
# absent (persistence then stores NULL for the conversation's id).
app_model_config_id: str | None = None
app_model_config_id: str
app_model_config_dict: dict[str, Any]
model: ModelConfigEntity
prompt_template: PromptTemplateEntity

View File

@ -1,6 +1,4 @@
import json
import re
from typing import Any
from core.app.app_config.entities import RagPipelineVariableEntity
from graphon.variables.input_entities import VariableEntity
@ -22,32 +20,10 @@ class WorkflowVariablesConfigManager:
# variables
for variable in user_input_form:
cls._normalize_json_schema(variable)
variables.append(VariableEntity.model_validate(variable))
return variables
@staticmethod
def _normalize_json_schema(variable: dict[str, Any]) -> None:
"""
Normalize ``json_schema`` from a JSON string to a dict.
The workflow graph is stored as JSON in the database. When a JSON
object variable carries a ``json_schema`` field, nested dicts are
preserved correctly, but older data or certain serialization paths
may store it as a JSON *string* instead of a native dict.
``VariableEntity.json_schema`` expects ``dict | None``, so we
deserialize the string here before handing it to Pydantic.
"""
json_schema = variable.get("json_schema")
if isinstance(json_schema, str):
try:
variable["json_schema"] = json.loads(json_schema)
except (json.JSONDecodeError, TypeError):
# Leave as-is; Pydantic validation will surface the error.
pass
@classmethod
def convert_rag_pipeline_variable(cls, workflow: Workflow, start_node_id: str) -> list[RagPipelineVariableEntity]:
"""

View File

@ -134,18 +134,17 @@ class AdvancedChatAppGenerateResponseConverter(
"created_at": chunk.created_at,
}
match sub_stream_response:
case MessageEndStreamResponse():
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
case ErrorStreamResponse():
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
case NodeStartStreamResponse() | NodeFinishStreamResponse():
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
case _:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
elif isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -161,17 +161,16 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
stream=stream,
)
match user:
case EndUser():
self._user_id = user.id
user_session_id = user.session_id
self._created_by_role = CreatorUserRole.END_USER
case Account():
self._user_id = user.id
user_session_id = user.id
self._created_by_role = CreatorUserRole.ACCOUNT
case _:
raise NotImplementedError(f"User type not supported: {type(user)}")
if isinstance(user, EndUser):
self._user_id = user.id
user_session_id = user.session_id
self._created_by_role = CreatorUserRole.END_USER
elif isinstance(user, Account):
self._user_id = user.id
user_session_id = user.id
self._created_by_role = CreatorUserRole.ACCOUNT
else:
raise NotImplementedError(f"User type not supported: {type(user)}")
self._workflow_system_variables = build_system_variables(
query=message.query,

View File

@ -1,106 +0,0 @@
"""Build the EasyUI-style app config for an Agent App from its Agent Soul.
An Agent App has no legacy ``app_model_config``: its model / prompt live in the
bound Agent Soul snapshot. To ride the existing chat message + SSE pipeline we
synthesize an ``app_model_config``-shaped dict from the Soul (model + system
prompt) plus any app-level feature flags (opening statement, follow-up, …)
stored on ``app_model_config`` when present, then reuse the same sub-managers
the chat app type uses.
"""
from typing import Any, cast
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager
from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager
from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager
from core.app.app_config.entities import (
EasyUIBasedAppConfig,
EasyUIBasedAppModelConfigFrom,
PromptTemplateEntity,
)
from models.agent_config_entities import AgentSoulConfig
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation
class AgentAppConfig(EasyUIBasedAppConfig):
"""Agent App config entity (EasyUI-shaped so it rides the chat pipeline).
``app_model_config_id`` is inherited as ``str | None``: an Agent App may have
no legacy ``app_model_config`` row, in which case persistence stores ``NULL``
for the conversation's ``app_model_config_id``.
"""
class AgentAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(
cls,
*,
app_model: App,
agent_soul: AgentSoulConfig,
app_model_config: AppModelConfig | None = None,
conversation: Conversation | None = None,
) -> AgentAppConfig:
"""Build the Agent App config from the Agent Soul (+ optional feature flags)."""
config_dict = cls._synthesize_config_dict(agent_soul, app_model_config)
# The synthesized dict is shaped like an app_model_config; the EasyUI
# sub-managers type their param as AppModelConfigDict (a TypedDict).
typed_config = cast(AppModelConfigDict, config_dict)
app_mode = AppMode.value_of(app_model.mode)
app_config = AgentAppConfig(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
app_mode=app_mode,
# The config is derived from the Agent Soul snapshot, not a legacy
# app_model_config row; the id is informational only.
app_model_config_from=EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG,
app_model_config_id=app_model_config.id if app_model_config else None,
app_model_config_dict=config_dict,
model=ModelConfigManager.convert(config=typed_config),
prompt_template=PromptTemplateConfigManager.convert(config=typed_config),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
dataset=DatasetConfigManager.convert(config=typed_config),
additional_features=cls.convert_features(config_dict, app_mode),
)
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
config=typed_config
)
return app_config
@staticmethod
def _synthesize_config_dict(
agent_soul: AgentSoulConfig,
app_model_config: AppModelConfig | None,
) -> dict[str, Any]:
"""Shape a Soul + feature flags into an ``app_model_config``-style dict.
Feature flags (opening statement / follow-up / tts / stt / citations /
moderation / annotation) come from ``app_model_config`` when present
(Q3: stored there), otherwise defaults; model + prompt always come from
the Agent Soul (the single source of truth for those).
"""
base: dict[str, Any] = dict(app_model_config.to_dict()) if app_model_config else {}
model = agent_soul.model
if model is not None:
base["model"] = {
"provider": model.model_provider,
"name": model.model,
"mode": "chat",
"completion_params": model.model_settings.model_dump(mode="json", exclude_none=True),
}
# The Agent Soul system prompt rides the EasyUI "simple" prompt slot; the
# agent backend is the real prompt authority, this only feeds the chat
# pipeline's bookkeeping (token counting, persistence).
base["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
base["pre_prompt"] = agent_soul.prompt.system_prompt or ""
# Agent App takes the user message directly; no completion-style inputs form.
base.setdefault("user_input_form", [])
return base
__all__ = ["AgentAppConfig", "AgentAppConfigManager"]

View File

@ -1,327 +0,0 @@
"""Agent App generator: orchestrate one conversation turn for an Agent App.
Mirrors the agent_chat generator (conversation + message + queue + streamed
response over the EasyUI chat pipeline), but the backing config comes from the
bound Agent Soul and the answer is produced by ``AgentAppRunner`` calling the
dify-agent backend rather than an in-process LLM/ReAct loop.
"""
from __future__ import annotations
import contextvars
import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any
from flask import Flask, current_app
from sqlalchemy import select
from clients.agent_backend import AgentBackendRunEventAdapter
from clients.agent_backend.factory import create_agent_backend_run_client
from configs import dify_config
from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.apps.agent_app.app_config_manager import AgentAppConfigManager
from core.app.apps.agent_app.app_runner import AgentAppRunner
from core.app.apps.agent_app.generate_response_converter import AgentAppGenerateResponseConverter
from core.app.apps.agent_app.runtime_request_builder import AgentAppRuntimeRequestBuilder
from core.app.apps.agent_app.session_store import AgentAppRuntimeSessionStore
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import (
AgentAppGenerateEntity,
DifyRunContext,
InvokeFrom,
UserFrom,
)
from core.app.llm.model_access import build_dify_model_access
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from models import Account, App, EndUser, Message
from models.agent import Agent, AgentConfigSnapshot, AgentScope, AgentSource, AgentStatus
from models.agent_config_entities import AgentSoulConfig
from services.conversation_service import ConversationService
logger = logging.getLogger(__name__)
class AgentAppGeneratorError(ValueError):
"""Raised when an Agent App turn cannot be set up."""
class AgentAppGenerator(MessageBasedAppGenerator):
def generate(
self,
*,
app_model: App,
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[Mapping | str, None, None]:
if not streaming:
raise AgentAppGeneratorError("Agent App only supports streaming mode")
query = args.get("query")
if not isinstance(query, str) or not query.strip():
raise AgentAppGeneratorError("query is required")
query = query.replace("\x00", "")
inputs = args["inputs"]
# Resolve the bound roster Agent + its published Agent Soul snapshot.
agent, snapshot, agent_soul = self._resolve_agent(app_model)
conversation = None
conversation_id = args.get("conversation_id")
if conversation_id:
conversation = ConversationService.get_conversation(
app_model=app_model, conversation_id=conversation_id, user=user
)
# Build the EasyUI-shaped config from the Agent Soul so the chat pipeline
# can persist usage; the answer itself comes from the agent backend.
app_model_config = app_model.app_model_config
app_config = AgentAppConfigManager.get_app_config(
app_model=app_model,
agent_soul=agent_soul,
app_model_config=app_model_config,
conversation=conversation,
)
model_conf = ModelConfigConverter.convert(app_config)
trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id)
application_generate_entity = AgentAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=model_conf,
conversation_id=conversation.id if conversation else None,
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=[],
parent_message_id=(
args.get("parent_message_id")
if invoke_from not in {InvokeFrom.SERVICE_API, InvokeFrom.OPENAPI}
else UUID_NIL
),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
extras={"auto_generate_conversation_name": args.get("auto_generate_name", True)},
call_depth=0,
trace_manager=trace_manager,
agent_id=agent.id,
agent_config_snapshot_id=snapshot.id,
)
conversation, message = self._init_generate_records(application_generate_entity, conversation)
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
)
context = contextvars.copy_context()
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"context": context,
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
"user_from": UserFrom.ACCOUNT if isinstance(user, Account) else UserFrom.END_USER,
},
)
worker_thread.start()
response = self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
stream=streaming,
)
return AgentAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(
self,
*,
flask_app: Flask,
context: contextvars.Context,
application_generate_entity: AgentAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
user_from: UserFrom,
) -> None:
from libs.flask_utils import preserve_flask_contexts
with preserve_flask_contexts(flask_app, context_vars=context):
try:
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
app_config = application_generate_entity.app_config
# Apply app-level input guards (content moderation + annotation
# reply) before reaching the Agent backend, mirroring the EasyUI
# chat / agent-chat runners. These can short-circuit the turn.
app_model = db.session.get(App, app_config.app_id)
if app_model is None:
raise AgentAppGeneratorError("App not found")
handled, query = self._run_input_guards(
application_generate_entity=application_generate_entity,
app_model=app_model,
message=message,
queue_manager=queue_manager,
)
if handled:
return
dify_context = DifyRunContext(
tenant_id=app_config.tenant_id,
app_id=app_config.app_id,
user_id=application_generate_entity.user_id,
user_from=user_from,
invoke_from=application_generate_entity.invoke_from,
)
credentials_provider, _ = build_dify_model_access(dify_context)
_, _, agent_soul = self._resolve_agent_by_id(
tenant_id=app_config.tenant_id,
agent_id=application_generate_entity.agent_id,
snapshot_id=application_generate_entity.agent_config_snapshot_id,
)
runner = AgentAppRunner(
request_builder=AgentAppRuntimeRequestBuilder(credentials_provider=credentials_provider),
agent_backend_client=create_agent_backend_run_client(
base_url=dify_config.AGENT_BACKEND_BASE_URL,
use_fake=dify_config.AGENT_BACKEND_USE_FAKE,
fake_scenario=dify_config.AGENT_BACKEND_FAKE_SCENARIO,
),
event_adapter=AgentBackendRunEventAdapter(),
session_store=AgentAppRuntimeSessionStore(),
)
runner.run(
dify_context=dify_context,
agent_id=application_generate_entity.agent_id,
agent_config_snapshot_id=application_generate_entity.agent_config_snapshot_id,
agent_soul=agent_soul,
conversation_id=conversation.id,
query=query,
message_id=message.id,
model_name=application_generate_entity.model_conf.model,
queue_manager=queue_manager,
)
except GenerateTaskStoppedError:
pass
except Exception as e:
logger.exception("Unknown Error in Agent App generate worker")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.close()
def _run_input_guards(
self,
*,
application_generate_entity: AgentAppGenerateEntity,
app_model: App,
message: Message,
queue_manager: AppQueueManager,
) -> tuple[bool, str]:
"""Apply input moderation + annotation reply before the backend call.
Returns ``(handled, query)``: when ``handled`` is True a direct answer
has already been published (a blocked/preset moderation response or a
matched annotation) and the backend turn must be skipped. Otherwise
``query`` is the possibly moderation-overridden query to send onward.
"""
from core.app.apps.agent_app.app_runner import publish_text_answer
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
app_config = application_generate_entity.app_config
model_name = application_generate_entity.model_conf.model
query = application_generate_entity.query
# content moderation (sensitive_word_avoidance); a blocked input yields a
# preset answer, an "overridden" action returns a sanitized query.
try:
_, _, query = InputModeration().check(
app_id=app_config.app_id,
tenant_id=app_config.tenant_id,
app_config=app_config,
inputs=dict(application_generate_entity.inputs),
query=query or "",
message_id=message.id,
trace_manager=application_generate_entity.trace_manager,
)
except ModerationError as e:
publish_text_answer(queue_manager=queue_manager, model_name=model_name, answer=str(e))
return True, query
# annotation reply: a matching annotation answers the turn deterministically.
if query:
annotation_reply = AnnotationReplyFeature().query(
app_record=app_model,
message=message,
query=query,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
)
if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER,
)
publish_text_answer(queue_manager=queue_manager, model_name=model_name, answer=annotation_reply.content)
return True, query
return False, query
def _resolve_agent(self, app_model: App) -> tuple[Agent, AgentConfigSnapshot, AgentSoulConfig]:
agent = db.session.scalar(
select(Agent).where(
Agent.app_id == app_model.id,
Agent.scope == AgentScope.ROSTER,
Agent.source == AgentSource.AGENT_APP,
Agent.status == AgentStatus.ACTIVE,
)
)
if agent is None:
raise AgentAppGeneratorError("Agent App has no bound Agent")
return self._resolve_agent_by_id(
tenant_id=app_model.tenant_id, agent_id=agent.id, snapshot_id=agent.active_config_snapshot_id
)
@staticmethod
def _resolve_agent_by_id(
*, tenant_id: str, agent_id: str, snapshot_id: str | None
) -> tuple[Agent, AgentConfigSnapshot, AgentSoulConfig]:
agent = db.session.scalar(select(Agent).where(Agent.id == agent_id, Agent.tenant_id == tenant_id))
if agent is None:
raise AgentAppGeneratorError("Agent not found")
if not snapshot_id:
raise AgentAppGeneratorError("Agent has no published version")
snapshot = db.session.scalar(select(AgentConfigSnapshot).where(AgentConfigSnapshot.id == snapshot_id))
if snapshot is None:
raise AgentAppGeneratorError("Agent published version not found")
agent_soul = AgentSoulConfig.model_validate(snapshot.config_snapshot_dict)
return agent, snapshot, agent_soul
__all__ = ["AgentAppGenerator", "AgentAppGeneratorError"]

View File

@ -1,200 +0,0 @@
"""Agent App runner: drive one conversation turn through the dify-agent backend.
Unlike the legacy ``AgentChatAppRunner`` (which runs an in-process ReAct loop),
this runner delegates to the Agent backend: build the run request from the
Agent Soul + conversation, create the run, consume its event stream, and
republish the assistant answer as chat queue events so the existing
EasyUI chat task pipeline persists the message and streams SSE. The conversation
``session_snapshot`` is saved on success for multi-turn continuity (S3).
"""
from __future__ import annotations
import json
import logging
from typing import Any
from pydantic import JsonValue
from clients.agent_backend import (
AgentBackendError,
AgentBackendInternalEventType,
AgentBackendRunClient,
AgentBackendRunEventAdapter,
AgentBackendRunSucceededInternalEvent,
AgentBackendStreamInternalEvent,
)
from core.app.apps.agent_app.runtime_request_builder import (
AgentAppRuntimeBuildContext,
AgentAppRuntimeRequestBuilder,
)
from core.app.apps.agent_app.session_store import AgentAppRuntimeSessionStore, AgentAppSessionScope
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import DifyRunContext
from core.app.entities.queue_entities import QueueLLMChunkEvent, QueueMessageEndEvent
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage
from models.agent_config_entities import AgentSoulConfig
logger = logging.getLogger(__name__)
def publish_text_answer(*, queue_manager: AppQueueManager, model_name: str, answer: str) -> None:
"""Publish a complete assistant answer as one chunk + message-end.
The EasyUI chat task pipeline consumes a QueueLLMChunkEvent stream followed
by a QueueMessageEndEvent; emitting the whole answer as a single chunk lets
both the backend-produced answer and short-circuited answers (moderation /
annotation reply) share the exact same persistence + SSE path.
"""
chunk = LLMResultChunk(
model=model_name,
prompt_messages=[],
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=answer)),
)
queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER)
queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_name,
prompt_messages=[],
message=AssistantPromptMessage(content=answer),
usage=LLMUsage.empty_usage(),
),
),
PublishFrom.APPLICATION_MANAGER,
)
class AgentAppRunner:
"""Runs one Agent App conversation turn against the Agent backend."""
def __init__(
self,
*,
request_builder: AgentAppRuntimeRequestBuilder,
agent_backend_client: AgentBackendRunClient,
event_adapter: AgentBackendRunEventAdapter,
session_store: AgentAppRuntimeSessionStore,
) -> None:
self._request_builder = request_builder
self._agent_backend_client = agent_backend_client
self._event_adapter = event_adapter
self._session_store = session_store
def run(
self,
*,
dify_context: DifyRunContext,
agent_id: str,
agent_config_snapshot_id: str,
agent_soul: AgentSoulConfig,
conversation_id: str,
query: str,
message_id: str,
model_name: str,
queue_manager: AppQueueManager,
) -> None:
scope = AgentAppSessionScope(
tenant_id=dify_context.tenant_id,
app_id=dify_context.app_id,
conversation_id=conversation_id,
agent_id=agent_id,
agent_config_snapshot_id=agent_config_snapshot_id,
)
session_snapshot = self._session_store.load_active_snapshot(scope)
runtime = self._request_builder.build(
AgentAppRuntimeBuildContext(
dify_context=dify_context,
agent_id=agent_id,
agent_config_snapshot_id=agent_config_snapshot_id,
agent_soul=agent_soul,
conversation_id=conversation_id,
user_query=query,
idempotency_key=message_id,
session_snapshot=session_snapshot,
)
)
create_response = self._agent_backend_client.create_run(runtime.request)
terminal = self._consume_stream(create_response.run_id, queue_manager=queue_manager)
if not isinstance(terminal, AgentBackendRunSucceededInternalEvent):
error = getattr(terminal, "error", None) or "Agent backend run did not complete successfully."
raise AgentBackendError(str(error))
answer = self._extract_answer(terminal.output)
self._publish_answer(queue_manager=queue_manager, model_name=model_name, answer=answer)
self._save_session(scope=scope, backend_run_id=terminal.run_id, snapshot=terminal.session_snapshot)
def _consume_stream(self, run_id: str, *, queue_manager: AppQueueManager):
terminal = None
for public_event in self._agent_backend_client.stream_events(run_id):
if queue_manager.is_stopped():
self._cancel_run(run_id)
raise GenerateTaskStoppedError()
for internal_event in self._event_adapter.adapt(public_event):
if queue_manager.is_stopped():
self._cancel_run(run_id)
raise GenerateTaskStoppedError()
if internal_event.type in (
AgentBackendInternalEventType.RUN_STARTED,
AgentBackendInternalEventType.STREAM_EVENT,
):
# Stream deltas are accumulated by the backend into the
# terminal output; token-level forwarding is an S3 refinement.
if isinstance(internal_event, AgentBackendStreamInternalEvent):
continue
continue
terminal = internal_event
break
if terminal is not None:
break
return terminal
def _cancel_run(self, run_id: str) -> None:
try:
self._agent_backend_client.cancel_run(run_id)
except Exception:
logger.warning("Failed to cancel stopped Agent App backend run: run_id=%s", run_id, exc_info=True)
def _publish_answer(self, *, queue_manager: AppQueueManager, model_name: str, answer: str) -> None:
# MVP: emit the full answer as a single chunk + message-end. The chat
# task pipeline streams the chunk over SSE and persists the message.
publish_text_answer(queue_manager=queue_manager, model_name=model_name, answer=answer)
def _save_session(self, *, scope: AgentAppSessionScope, backend_run_id: str, snapshot: Any) -> None:
try:
self._session_store.save_active_snapshot(scope=scope, backend_run_id=backend_run_id, snapshot=snapshot)
except Exception:
logger.warning(
"Failed to persist Agent App conversation session snapshot: "
"tenant_id=%s app_id=%s conversation_id=%s agent_id=%s",
scope.tenant_id,
scope.app_id,
scope.conversation_id,
scope.agent_id,
exc_info=True,
)
@staticmethod
def _extract_answer(output: JsonValue) -> str:
"""Normalize the backend's terminal output to assistant text.
Free-text Agent Apps return a plain string; if a structured output is
configured the value is a JSON object, which we serialize so the chat
message always has a string body.
"""
if isinstance(output, str):
return output
if isinstance(output, dict):
text = output.get("text")
if isinstance(text, str):
return text
return json.dumps(output, ensure_ascii=False)
return json.dumps(output, ensure_ascii=False)
__all__ = ["AgentAppRunner", "publish_text_answer"]

View File

@ -1,15 +0,0 @@
"""Response converter for the Agent App type.
The Agent App streams the same chatbot response shape as the chat / agent-chat
app types, so it reuses that converter wholesale; kept as a distinct subclass so
the app type owns its converter and can diverge later.
"""
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
class AgentAppGenerateResponseConverter(AgentChatAppGenerateResponseConverter):
pass
__all__ = ["AgentAppGenerateResponseConverter"]

View File

@ -1,177 +0,0 @@
"""Build dify-agent run requests for one Agent App conversation turn.
Mirrors the workflow ``WorkflowAgentRuntimeRequestBuilder`` but for the Agent
App surface: the user prompt is the chat message (no workflow-node job / no
previous-node context), and multi-turn continuity flows through the
conversation-keyed ``session_snapshot`` plus the history layer.
"""
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Protocol, cast
from agenton.compositor import CompositorSessionSnapshot
from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig
from dify_agent.protocol import CreateRunRequest
from clients.agent_backend import (
AgentBackendAgentAppRunInput,
AgentBackendModelConfig,
AgentBackendRunRequestBuilder,
redact_for_agent_backend_log,
)
from core.app.entities.app_invoke_entities import DifyRunContext
from core.workflow.nodes.agent_v2.plugin_tools_builder import (
WorkflowAgentPluginToolsBuilder,
WorkflowAgentPluginToolsBuildError,
)
from models.agent_config_entities import AgentSoulConfig
from models.provider_ids import ModelProviderID
class AgentAppRuntimeRequestBuildError(ValueError):
"""Raised when Agent App state cannot be mapped to a valid run request."""
def __init__(self, error_code: str, message: str) -> None:
self.error_code = error_code
super().__init__(message)
class CredentialsProvider(Protocol):
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: ...
@dataclass(frozen=True, slots=True)
class AgentAppRuntimeBuildContext:
dify_context: DifyRunContext
agent_id: str
agent_config_snapshot_id: str
agent_soul: AgentSoulConfig
conversation_id: str
user_query: str
idempotency_key: str
session_snapshot: CompositorSessionSnapshot | None = None
@dataclass(frozen=True, slots=True)
class AgentAppRuntimeRequest:
request: CreateRunRequest
redacted_request: dict[str, Any]
metadata: dict[str, Any]
class AgentAppRuntimeRequestBuilder:
"""Build dify-agent run requests from Agent App conversation state."""
def __init__(
self,
*,
credentials_provider: CredentialsProvider,
request_builder: AgentBackendRunRequestBuilder | None = None,
plugin_tools_builder: WorkflowAgentPluginToolsBuilder | None = None,
) -> None:
self._credentials_provider = credentials_provider
self._request_builder = request_builder or AgentBackendRunRequestBuilder()
self._plugin_tools_builder = plugin_tools_builder or WorkflowAgentPluginToolsBuilder()
def build(self, context: AgentAppRuntimeBuildContext) -> AgentAppRuntimeRequest:
agent_soul = context.agent_soul
if agent_soul.model is None:
raise AgentAppRuntimeRequestBuildError(
"agent_model_not_configured",
"Agent App requires the Agent Soul model to be configured.",
)
metadata = self._build_metadata(context)
credentials = self._credentials_provider.fetch(agent_soul.model.model_provider, agent_soul.model.model)
try:
tools_layer = self._plugin_tools_builder.build(
tenant_id=context.dify_context.tenant_id,
app_id=context.dify_context.app_id,
user_id=context.dify_context.user_id,
tools=agent_soul.tools,
invoke_from=context.dify_context.invoke_from,
)
except WorkflowAgentPluginToolsBuildError as error:
raise AgentAppRuntimeRequestBuildError(error.error_code, str(error)) from error
if tools_layer is not None:
metadata["agent_tools"] = {
"dify_tool_count": len(tools_layer.tools),
"dify_tool_names": [tool.name or tool.tool_name for tool in tools_layer.tools],
}
request = self._request_builder.build_for_agent_app(
AgentBackendAgentAppRunInput(
model=AgentBackendModelConfig(
plugin_id=self._plugin_daemon_plugin_id(
plugin_id=agent_soul.model.plugin_id,
model_provider=agent_soul.model.model_provider,
),
model_provider=self._plugin_daemon_provider_name(agent_soul.model.model_provider),
model=agent_soul.model.model,
credentials=self._normalize_credentials(credentials),
model_settings=agent_soul.model.model_settings.model_dump(mode="json", exclude_none=True),
),
execution_context=DifyExecutionContextLayerConfig(
tenant_id=context.dify_context.tenant_id,
user_id=context.dify_context.user_id,
app_id=context.dify_context.app_id,
conversation_id=context.conversation_id,
agent_id=context.agent_id,
agent_config_version_id=context.agent_config_snapshot_id,
invoke_from="agent_app",
),
agent_soul_prompt=agent_soul.prompt.system_prompt or None,
user_prompt=context.user_query,
tools=tools_layer,
session_snapshot=context.session_snapshot,
idempotency_key=context.idempotency_key,
metadata=metadata,
)
)
redacted = cast(dict[str, Any], redact_for_agent_backend_log(request))
return AgentAppRuntimeRequest(request=request, redacted_request=redacted, metadata=metadata)
@staticmethod
def _build_metadata(context: AgentAppRuntimeBuildContext) -> dict[str, Any]:
return {
"tenant_id": context.dify_context.tenant_id,
"app_id": context.dify_context.app_id,
"conversation_id": context.conversation_id,
"agent_id": context.agent_id,
"agent_config_snapshot_id": context.agent_config_snapshot_id,
}
@staticmethod
def _plugin_daemon_plugin_id(*, plugin_id: str, model_provider: str) -> str:
"""Return the transport plugin id expected by plugin-daemon headers."""
if plugin_id.count("/") == 1:
return plugin_id
if plugin_id:
return ModelProviderID(plugin_id).plugin_id
return ModelProviderID(model_provider).plugin_id
@staticmethod
def _plugin_daemon_provider_name(model_provider: str) -> str:
"""Return the provider name expected by plugin-daemon dispatch payloads."""
return ModelProviderID(model_provider).provider_name
@staticmethod
def _normalize_credentials(credentials: Mapping[str, Any]) -> dict[str, str | int | float | bool | None]:
normalized: dict[str, str | int | float | bool | None] = {}
for key, value in credentials.items():
if isinstance(value, str | int | float | bool) or value is None:
normalized[key] = value
else:
normalized[key] = str(value)
return normalized
__all__ = [
"AgentAppRuntimeBuildContext",
"AgentAppRuntimeRequest",
"AgentAppRuntimeRequestBuildError",
"AgentAppRuntimeRequestBuilder",
]

View File

@ -1,106 +0,0 @@
"""Conversation-keyed Agent backend session store for the Agent App type.
Shares the unified ``agent_runtime_sessions`` table with the workflow Agent
Node store, but owns rows with ``owner_type = conversation``: one Agent App
conversation maps to one Agent session, so multi-turn chat re-enters the same
``session_snapshot``. Cross-conversation memory (PRD Global / Per app) is a
phase-2 concern and not modeled here.
"""
from __future__ import annotations
from dataclasses import dataclass
from agenton.compositor import CompositorSessionSnapshot
from sqlalchemy import select
from core.db.session_factory import session_factory
from libs.datetime_utils import naive_utc_now
from models.agent import (
AgentRuntimeSession,
AgentRuntimeSessionOwnerType,
AgentRuntimeSessionStatus,
)
@dataclass(frozen=True, slots=True)
class AgentAppSessionScope:
"""Identity of one Agent App conversation session."""
tenant_id: str
app_id: str
conversation_id: str
agent_id: str
agent_config_snapshot_id: str
class AgentAppRuntimeSessionStore:
"""Persists Agent backend session snapshots for Agent App conversations."""
def load_active_snapshot(self, scope: AgentAppSessionScope) -> CompositorSessionSnapshot | None:
with session_factory.create_session() as session:
row = session.scalar(self._active_stmt(scope))
if row is None:
return None
return CompositorSessionSnapshot.model_validate_json(row.session_snapshot)
def save_active_snapshot(
self,
*,
scope: AgentAppSessionScope,
backend_run_id: str,
snapshot: CompositorSessionSnapshot | None,
) -> None:
if snapshot is None:
return
snapshot_json = snapshot.model_dump_json()
with session_factory.create_session() as session:
row = session.scalar(self._scope_stmt(scope))
if row is None:
row = AgentRuntimeSession(
tenant_id=scope.tenant_id,
app_id=scope.app_id,
owner_type=AgentRuntimeSessionOwnerType.CONVERSATION,
agent_id=scope.agent_id,
agent_config_snapshot_id=scope.agent_config_snapshot_id,
conversation_id=scope.conversation_id,
backend_run_id=backend_run_id,
session_snapshot=snapshot_json,
composition_layer_specs="[]",
status=AgentRuntimeSessionStatus.ACTIVE,
)
session.add(row)
else:
row.backend_run_id = backend_run_id
row.session_snapshot = snapshot_json
row.status = AgentRuntimeSessionStatus.ACTIVE
row.cleaned_at = None
session.commit()
def mark_cleaned(self, *, scope: AgentAppSessionScope, backend_run_id: str | None = None) -> None:
with session_factory.create_session() as session:
row = session.scalar(self._active_stmt(scope))
if row is None:
return
if backend_run_id is not None:
row.backend_run_id = backend_run_id
row.status = AgentRuntimeSessionStatus.CLEANED
row.cleaned_at = naive_utc_now()
session.commit()
@staticmethod
def _scope_stmt(scope: AgentAppSessionScope):
return select(AgentRuntimeSession).where(
AgentRuntimeSession.owner_type == AgentRuntimeSessionOwnerType.CONVERSATION,
AgentRuntimeSession.tenant_id == scope.tenant_id,
AgentRuntimeSession.conversation_id == scope.conversation_id,
AgentRuntimeSession.agent_id == scope.agent_id,
AgentRuntimeSession.agent_config_snapshot_id == scope.agent_config_snapshot_id,
)
@classmethod
def _active_stmt(cls, scope: AgentAppSessionScope):
return cls._scope_stmt(scope).where(AgentRuntimeSession.status == AgentRuntimeSessionStatus.ACTIVE)
__all__ = ["AgentAppRuntimeSessionStore", "AgentAppSessionScope"]

View File

@ -112,16 +112,15 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
"created_at": chunk.created_at,
}
match sub_stream_response:
case MessageEndStreamResponse():
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
case ErrorStreamResponse():
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
case _:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
elif isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -134,10 +134,6 @@ class AppQueueManager(ABC):
self._check_for_sqlalchemy_models(event.model_dump())
self._publish(event, pub_from)
def is_stopped(self) -> bool:
"""Return whether the current task has been manually stopped."""
return self._is_stopped()
@abstractmethod
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
@ -213,16 +209,14 @@ class AppQueueManager(ABC):
def _check_for_sqlalchemy_models(self, data: Any):
# from entity to dict or list
match data:
case dict():
for value in data.values():
self._check_for_sqlalchemy_models(value)
case list():
for item in data:
self._check_for_sqlalchemy_models(item)
case _:
if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
raise TypeError(
"Critical Error: Passing SQLAlchemy Model instances that"
" cause thread safety issues is not allowed."
)
if isinstance(data, dict):
for value in data.values():
self._check_for_sqlalchemy_models(value)
elif isinstance(data, list):
for item in data:
self._check_for_sqlalchemy_models(item)
else:
if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
raise TypeError(
"Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed."
)

View File

@ -305,27 +305,26 @@ class AppRunner:
text += message.content
elif isinstance(message.content, list):
for content in message.content:
match content:
case str():
text += content
case TextPromptMessageContent():
text += content.data
case ImagePromptMessageContent():
if message_id and user_id and tenant_id:
try:
self._handle_multimodal_image_content(
content=content,
message_id=message_id,
user_id=user_id,
tenant_id=tenant_id,
queue_manager=queue_manager,
)
except Exception:
_logger.exception("Failed to handle multimodal image output")
else:
_logger.warning("Received multimodal output but missing required parameters")
case _:
text += content.data if hasattr(content, "data") else str(content)
if isinstance(content, str):
text += content
elif isinstance(content, TextPromptMessageContent):
text += content.data
elif isinstance(content, ImagePromptMessageContent):
if message_id and user_id and tenant_id:
try:
self._handle_multimodal_image_content(
content=content,
message_id=message_id,
user_id=user_id,
tenant_id=tenant_id,
queue_manager=queue_manager,
)
except Exception:
_logger.exception("Failed to handle multimodal image output")
else:
_logger.warning("Received multimodal output but missing required parameters")
else:
text += content.data if hasattr(content, "data") else str(content)
if not model:
model = result.model

View File

@ -112,16 +112,15 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
"created_at": chunk.created_at,
}
match sub_stream_response:
case MessageEndStreamResponse():
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
case ErrorStreamResponse():
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
case _:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
elif isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -276,18 +276,17 @@ class WorkflowResponseConverter:
created_by: CreatedByDict | dict[str, object] = {}
user = self._user
match user:
case Account():
created_by = AccountCreatedByDict(
id=user.id,
name=user.name,
email=user.email,
)
case EndUser():
created_by = EndUserCreatedByDict(
id=user.id,
user=user.session_id,
)
if isinstance(user, Account):
created_by = AccountCreatedByDict(
id=user.id,
name=user.name,
email=user.email,
)
elif isinstance(user, EndUser):
created_by = EndUserCreatedByDict(
id=user.id,
user=user.session_id,
)
return WorkflowFinishStreamResponse(
task_id=task_id,
@ -456,18 +455,17 @@ class WorkflowResponseConverter:
created_by: Mapping[str, object]
user = creator_user
match user:
case Account():
created_by = {
"id": user.id,
"name": user.name,
"email": user.email,
}
case _:
created_by = {
"id": user.id,
"user": user.session_id,
}
if isinstance(user, Account):
created_by = {
"id": user.id,
"name": user.name,
"email": user.email,
}
else:
created_by = {
"id": user.id,
"user": user.session_id,
}
return WorkflowFinishStreamResponse(
task_id=task_id,
@ -564,16 +562,15 @@ class WorkflowResponseConverter:
outputs, outputs_truncated = self._truncate_mapping(encoded_outputs)
metadata = self._merge_metadata(event.execution_metadata, snapshot)
match event:
case QueueNodeSucceededEvent():
status = WorkflowNodeExecutionStatus.SUCCEEDED
error_message = event.error
case QueueNodeFailedEvent():
status = WorkflowNodeExecutionStatus.FAILED
error_message = event.error
case _:
status = WorkflowNodeExecutionStatus.EXCEPTION
error_message = event.error
if isinstance(event, QueueNodeSucceededEvent):
status = WorkflowNodeExecutionStatus.SUCCEEDED
error_message = event.error
elif isinstance(event, QueueNodeFailedEvent):
status = WorkflowNodeExecutionStatus.FAILED
error_message = event.error
else:
status = WorkflowNodeExecutionStatus.EXCEPTION
error_message = event.error
return NodeFinishStreamResponse(
task_id=task_id,

View File

@ -109,18 +109,17 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
"created_at": chunk.created_at,
}
match sub_stream_response:
case MessageEndStreamResponse():
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
if not isinstance(metadata, dict):
metadata = {}
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
case ErrorStreamResponse():
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
case _:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
if not isinstance(metadata, dict):
metadata = {}
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
elif isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any, cast, override
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@ -15,7 +15,6 @@ from core.app.entities.task_entities import (
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]):
@classmethod
@override
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
"""
Convert blocking full response.
@ -25,7 +24,6 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[Workflow
return dict(blocking_response.model_dump())
@classmethod
@override
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
"""
Convert blocking simple response.
@ -35,7 +33,6 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[Workflow
return cls.convert_blocking_full_response(blocking_response)
@classmethod
@override
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:
@ -69,7 +66,6 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[Workflow
yield response_chunk
@classmethod
@override
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:

View File

@ -200,21 +200,6 @@ class AgentChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGe
pass
class AgentAppGenerateEntity(ChatAppGenerateEntity):
"""
Agent App (new Agent app type) Generate Entity.
Subclasses ``ChatAppGenerateEntity`` so it rides the exact same EasyUI chat
pipeline (generator, task pipeline, message cycle) without widening every
accepted-entity union. The answer is produced by the dify-agent backend
rather than an in-process LLM call; ``model_conf`` is synthesized from the
bound Agent Soul model so the chat task pipeline can persist usage.
"""
agent_id: str
agent_config_snapshot_id: str
class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
"""
Advanced Chat Application Generate Entity.

View File

@ -46,14 +46,13 @@ class BasedGenerateTaskPipeline:
e = event.error
err: Exception
match e:
case InvokeAuthorizationError():
err = InvokeAuthorizationError("Incorrect API key provided")
case InvokeError() | ValueError():
err = e
case _:
description = getattr(e, "description", None)
err = Exception(description if description is not None else str(e))
if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided")
elif isinstance(e, InvokeError | ValueError):
err = e
else:
description = getattr(e, "description", None)
err = Exception(description if description is not None else str(e))
if not message_id or not session:
return err

View File

@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Literal, override
from configs import dify_config
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
from core.db.session_factory import session_factory
from core.file import remote_fetcher
from core.helper.ssrf_proxy import graphon_ssrf_proxy
from core.tools.signature import sign_tool_file
from core.workflow.file_reference import parse_file_reference
from extensions.ext_storage import storage
@ -46,7 +46,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
@override
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
return remote_fetcher.graphon_remote_file_fetcher.get(url, follow_redirects=follow_redirects)
return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
@override
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:

View File

@ -91,28 +91,26 @@ class AppGeneratorTTSPublisher:
)
future_queue.put(futures_result)
break
else:
match message.event:
case QueueAgentMessageEvent() | QueueLLMChunkEvent():
message_content = message.event.chunk.delta.message.content
if not message_content:
continue
match message_content:
case str():
self.msg_text += message_content
case list():
for content in message_content:
if not isinstance(content, TextPromptMessageContent):
continue
self.msg_text += content.data
case QueueTextChunkEvent():
self.msg_text += message.event.text
case QueueNodeSucceededEvent():
if message.event.outputs is None:
continue
output = message.event.outputs.get("output", "")
if isinstance(output, str):
self.msg_text += output
elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
message_content = message.event.chunk.delta.message.content
if not message_content:
continue
match message_content:
case str():
self.msg_text += message_content
case list():
for content in message_content:
if not isinstance(content, TextPromptMessageContent):
continue
self.msg_text += content.data
elif isinstance(message.event, QueueTextChunkEvent):
self.msg_text += message.event.text
elif isinstance(message.event, QueueNodeSucceededEvent):
if message.event.outputs is None:
continue
output = message.event.outputs.get("output", "")
if isinstance(output, str):
self.msg_text += output
self.last_message = message
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
if len(sentence_arr) >= min(self.max_sentence, 7):

View File

@ -12,7 +12,7 @@ from uuid import uuid4
import httpx
from configs import dify_config
from core.file import remote_fetcher
from core.helper import ssrf_proxy
from extensions.ext_database import db
from extensions.ext_storage import storage
from extensions.storage.storage_type import StorageType
@ -44,6 +44,26 @@ class DatasourceFileManager:
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@staticmethod
def verify_file(datasource_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
"""
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
recalculated_sign = hmac.new(
dify_config.SECRET_KEY.encode(),
data_to_sign.encode(),
hashlib.sha256,
).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
@staticmethod
def create_file_by_raw(
*,
@ -97,7 +117,7 @@ class DatasourceFileManager:
) -> ToolFile:
# try to download image
try:
response = remote_fetcher.make_request("GET", file_url)
response = ssrf_proxy.get(file_url)
response.raise_for_status()
blob = response.content
except httpx.TimeoutException:

View File

@ -1,5 +0,0 @@
"""File retrieval helpers shared by backend file-oriented workflows."""
from . import remote_fetcher
__all__ = ["remote_fetcher"]

View File

@ -1,345 +0,0 @@
"""Unified remote-file retrieval with Dify signed file URL resolution.
Use this module for backend workflows whose intent is to fetch remote file content
or remote file metadata from a URL, even when the URL originally came from a user
upload, a workflow variable, a tool/datasource file, or an app DSL. GET/HEAD
requests can resolve Dify-signed file URLs locally through DB + storage before
falling back to the SSRF-protected network client.
Use `core.helper.ssrf_proxy` directly only for generic outbound HTTP where the
URL is not being treated as a remote file, such as HTTP Request nodes, external
API integrations, auth discovery, or user-configured tool calls. Those calls must
stay as real network requests and should not reinterpret Dify file URLs as stored
files.
"""
from __future__ import annotations
import base64
import hashlib
import hmac
import re
import time
import urllib.parse
from dataclasses import dataclass
from typing import Any, Literal
import httpx
from configs import dify_config
from core.app.file_access import DatabaseFileAccessController
from core.db.session_factory import session_factory
from core.helper import ssrf_proxy
from core.helper.ssrf_proxy import (
SSRF_DEFAULT_MAX_RETRIES,
_to_graphon_http_response,
max_retries_exceeded_error,
request_error,
)
from extensions.ext_storage import storage
from models import ToolFile, UploadFile
_UPLOAD_FILE_PATH_PATTERN = re.compile(
r"^/files/(?P<file_id>[a-fA-F0-9-]+)/(?P<preview_kind>file-preview|image-preview)$"
)
_TOOL_FILE_PATH_PATTERN = re.compile(r"^/files/tools/(?P<file_id>[a-fA-F0-9-]+)(?P<extension>\.[^/]*)?$")
_DATASOURCE_FILE_PATH_PATTERN = re.compile(r"^/files/datasources/(?P<file_id>[a-fA-F0-9-]+)(?P<extension>\.[^/]*)?$")
_file_access_controller = DatabaseFileAccessController()
@dataclass(frozen=True)
class _SignedFileUrl:
file_id: str
preview_kind: Literal["file-preview", "image-preview"]
record_kind: Literal["upload", "tool", "datasource"]
def make_request(method: str, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
"""Fetch remote file content or metadata.
GET and HEAD requests for Dify-owned signed file URLs are served from local
storage. Every other request is delegated unchanged to the SSRF proxy.
"""
normalized_method = method.upper()
if normalized_method == "GET":
response = _resolve_dify_signed_file_url("GET", url)
if response is not None:
return response
if normalized_method == "HEAD":
response = _resolve_dify_signed_file_url("HEAD", url)
if response is not None:
return response
return ssrf_proxy.make_request(method=method, url=url, max_retries=max_retries, **kwargs)
class GraphonRemoteFileFetcher:
"""Graphon HTTP-client adapter backed by the unified remote-file fetcher.
Graphon requires method-specific HTTP client methods, while regular Dify
call sites should use `make_request` directly.
"""
@property
def max_retries_exceeded_error(self) -> type[Exception]:
return max_retries_exceeded_error
@property
def request_error(self) -> type[Exception]:
return request_error
def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any):
return _to_graphon_http_response(make_request("GET", url=url, max_retries=max_retries, **kwargs))
def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any):
return _to_graphon_http_response(make_request("HEAD", url=url, max_retries=max_retries, **kwargs))
def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any):
return _to_graphon_http_response(make_request("POST", url=url, max_retries=max_retries, **kwargs))
def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any):
return _to_graphon_http_response(make_request("PUT", url=url, max_retries=max_retries, **kwargs))
def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any):
return _to_graphon_http_response(make_request("DELETE", url=url, max_retries=max_retries, **kwargs))
def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any):
return _to_graphon_http_response(make_request("PATCH", url=url, max_retries=max_retries, **kwargs))
def _resolve_dify_signed_file_url(method: Literal["GET", "HEAD"], url: str) -> httpx.Response | None:
parsed_url = urllib.parse.urlparse(url)
if not _is_dify_file_origin(parsed_url):
return None
signed_file_url = _parse_signed_file_path(parsed_url.path)
if signed_file_url is None:
return None
query = urllib.parse.parse_qs(parsed_url.query, keep_blank_values=True)
timestamp = _single_query_value(query, "timestamp")
nonce = _single_query_value(query, "nonce")
sign = _single_query_value(query, "sign")
if timestamp is None or nonce is None or sign is None:
return None
if not _verify_signed_file_url(
signed_file_url=signed_file_url,
timestamp=timestamp,
nonce=nonce,
sign=sign,
):
return None
if signed_file_url.record_kind == "upload":
return _build_upload_file_response(method=method, url=url, file_id=signed_file_url.file_id)
if signed_file_url.record_kind == "tool":
return _build_tool_file_response(method=method, url=url, file_id=signed_file_url.file_id)
return _build_datasource_file_response(method=method, url=url, file_id=signed_file_url.file_id)
def _parse_signed_file_path(path: str) -> _SignedFileUrl | None:
upload_match = _UPLOAD_FILE_PATH_PATTERN.match(path)
if upload_match:
preview_kind: Literal["file-preview", "image-preview"]
if upload_match.group("preview_kind") == "image-preview":
preview_kind = "image-preview"
else:
preview_kind = "file-preview"
return _SignedFileUrl(
file_id=upload_match.group("file_id"),
preview_kind=preview_kind,
record_kind="upload",
)
tool_match = _TOOL_FILE_PATH_PATTERN.match(path)
if tool_match:
return _SignedFileUrl(
file_id=tool_match.group("file_id"),
preview_kind="file-preview",
record_kind="tool",
)
datasource_match = _DATASOURCE_FILE_PATH_PATTERN.match(path)
if datasource_match:
return _SignedFileUrl(
file_id=datasource_match.group("file_id"),
preview_kind="file-preview",
record_kind="datasource",
)
return None
def _is_dify_file_origin(parsed_url: urllib.parse.ParseResult) -> bool:
if parsed_url.scheme not in {"http", "https"} or not parsed_url.hostname:
return False
url_origin = _origin_parts(parsed_url)
if url_origin is None:
return False
allowed_origins = {
origin
for configured_url in [dify_config.FILES_URL, dify_config.INTERNAL_FILES_URL]
if configured_url and (origin := _origin_parts(urllib.parse.urlparse(configured_url))) is not None
}
return url_origin in allowed_origins
def _origin_parts(parsed_url: urllib.parse.ParseResult) -> tuple[str, str, int] | None:
if parsed_url.scheme not in {"http", "https"} or not parsed_url.hostname:
return None
try:
port = parsed_url.port
except ValueError:
return None
return parsed_url.scheme, parsed_url.hostname.lower(), port or _default_port(parsed_url.scheme)
def _default_port(scheme: str) -> int:
return 443 if scheme == "https" else 80
def _single_query_value(query: dict[str, list[str]], key: str) -> str | None:
values = query.get(key)
if not values or len(values) != 1:
return None
return values[0]
def _verify_signed_file_url(
*,
signed_file_url: _SignedFileUrl,
timestamp: str,
nonce: str,
sign: str,
) -> bool:
try:
current_time = int(time.time())
signed_at = int(timestamp)
except ValueError:
return False
if current_time - signed_at > dify_config.FILES_ACCESS_TIMEOUT:
return False
payload = f"{signed_file_url.preview_kind}|{signed_file_url.file_id}|{timestamp}|{nonce}"
recalculated = hmac.new(dify_config.SECRET_KEY.encode(), payload.encode(), hashlib.sha256).digest()
expected = base64.urlsafe_b64encode(recalculated).decode()
return hmac.compare_digest(sign, expected)
def _build_upload_file_response(*, method: Literal["GET", "HEAD"], url: str, file_id: str) -> httpx.Response:
with session_factory.create_session() as session:
upload_file = _file_access_controller.get_upload_file(session=session, file_id=file_id)
if upload_file is None:
return _build_response(method=method, url=url, status_code=404)
content = b"" if method == "HEAD" else storage.load_once(upload_file.key)
return _build_response(
method=method,
url=url,
status_code=200,
content=content,
content_length=upload_file.size,
content_type=upload_file.mime_type,
filename=upload_file.name,
)
def _build_tool_file_response(*, method: Literal["GET", "HEAD"], url: str, file_id: str) -> httpx.Response:
with session_factory.create_session() as session:
tool_file = _file_access_controller.get_tool_file(session=session, file_id=file_id)
if tool_file is None:
return _build_response(method=method, url=url, status_code=404)
content = b"" if method == "HEAD" else storage.load_once(tool_file.file_key)
return _build_response(
method=method,
url=url,
status_code=200,
content=content,
content_length=tool_file.size,
content_type=tool_file.mimetype,
filename=tool_file.name,
)
def _build_datasource_file_response(*, method: Literal["GET", "HEAD"], url: str, file_id: str) -> httpx.Response:
with session_factory.create_session() as session:
upload_file = _file_access_controller.get_upload_file(session=session, file_id=file_id)
if upload_file is not None:
return _build_upload_file_record_response(method=method, url=url, upload_file=upload_file)
tool_file = _file_access_controller.get_tool_file(session=session, file_id=file_id)
if tool_file is not None:
return _build_tool_file_record_response(method=method, url=url, tool_file=tool_file)
return _build_response(method=method, url=url, status_code=404)
def _build_upload_file_record_response(
*,
method: Literal["GET", "HEAD"],
url: str,
upload_file: UploadFile,
) -> httpx.Response:
content = b"" if method == "HEAD" else storage.load_once(upload_file.key)
return _build_response(
method=method,
url=url,
status_code=200,
content=content,
content_length=upload_file.size,
content_type=upload_file.mime_type,
filename=upload_file.name,
)
def _build_tool_file_record_response(
*,
method: Literal["GET", "HEAD"],
url: str,
tool_file: ToolFile,
) -> httpx.Response:
content = b"" if method == "HEAD" else storage.load_once(tool_file.file_key)
return _build_response(
method=method,
url=url,
status_code=200,
content=content,
content_length=tool_file.size,
content_type=tool_file.mimetype,
filename=tool_file.name,
)
def _build_response(
*,
method: Literal["GET", "HEAD"],
url: str,
status_code: int,
content: bytes = b"",
content_length: int | None = None,
content_type: str | None = None,
filename: str | None = None,
) -> httpx.Response:
headers: dict[str, str] = {}
if content_type:
headers["Content-Type"] = content_type
if content_length is not None and content_length >= 0:
headers["Content-Length"] = str(content_length)
if filename:
headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{urllib.parse.quote(filename)}"
return httpx.Response(
status_code=status_code,
headers=headers,
content=content,
request=httpx.Request(method, url),
)
graphon_remote_file_fetcher = GraphonRemoteFileFetcher()

Some files were not shown because too many files have changed in this diff Show More