Compare commits

..

132 Commits

Author SHA1 Message Date
77a6bc28a7 docs(cli): document skills install in the README
Replace the `skill init` bundle description with the `skills install`
contract: dry-run default, --yes / --agent / <dir> / --stdout, and
config-directory agent detection.
2026-06-03 21:48:17 +08:00
5eb26785ba feat(cli/skills): add skills install with multi-agent detection
Install the embedded SKILL.md into local agents. Detection is by config
directory existence (no subprocess): Claude Code (~/.claude →
~/.claude/skills), Codex (~/.codex → ~/.agents/skills, per OpenAI docs),
opencode (~/.config/opencode → ~/.config/opencode/skills).

The command defaults to a dry-run that lists targets and writes nothing;
--yes writes to every detected agent and prints each path; --agent narrows
to a subset; a positional <dir> forces one explicit directory; --stdout
prints the skill for piping. Illegal flag combinations and unknown/undetected
agents exit 2.
2026-06-03 21:48:05 +08:00
1f2492d635 refactor(cli/skill): embed a single pure-delegation SKILL.md
Replace the runtime tree-walking skill generator with one hand-authored,
version-stamped SKILL.md embedded as a source constant. renderSkill() now
only substitutes {{VERSION}} — no command-tree walk, no Safety/Workflow
enumeration, no reference/*.md. The skill points agents at
`difyctl help -o json` for the live command surface, so it cannot drift
from the binary.

Retires `skill init`. skill.test.ts now guards the zero-enumeration
invariant (no command/flag listing leaks into the file) plus an exact
version-substitution golden.
2026-06-03 21:47:26 +08:00
23def1ef3f feat(cli/help): surface a GLOBAL FLAGS section in the top-level overview
Cross-command flags (`-o/--output`, `-v/--verbose`, `--http-retry`) were
invisible from `difyctl --help` — a human had to drill into a command or the
agent guide to learn the output selector exists. Add a GLOBAL FLAGS section to
the top-level overview, sourced from a single `GLOBAL_FLAG_HELP` list in the
help contract; `-o`'s accepted values are derived from `CONTRACT.outputFormats`
so they can never drift from the machine-readable contract.

Promote `helpGroup` to a typed field on `FlagDefinition` (it was an untyped tag
on `httpRetryFlag`) and mark the built-in `verbose` global flag with it, so the
cross-command intent is expressed in the type system.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-03 19:14:02 +08:00
831a53e088 feat(cli/help): full-depth top-level overview with usage and examples
Replace the bespoke two-level `printTopLevelHelp` loop with a full-depth
render derived from the same `collectCommands` walk the `-o json` site map
uses, so every runnable leaf — including third-level commands like
`auth devices list/revoke` — is visible and addressable by its canonical
`path.join(' ')` form. Leaves group by first segment with a blank line
between groups; `version` now reads as a leaf rather than a childless group.

Fold the human overview into `formatTopLevelHelp`'s text branch (mirroring
per-command `formatHelp`) and add gh-style USAGE, curated EXAMPLES, and a
LEARN MORE pointer block so a newcomer or agent can reach a first app run
from `--help` alone. The structured `-o json` branch is unchanged.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-03 19:08:52 +08:00
60295b66c3 fix(cli/help): drill into namespace nodes on <group> --help
Group nodes like `auth` and `auth devices` have no command of their own
(no index.ts), so `resolveCommand` missed them and `difyctl auth --help`
errored with "unknown help topic" — leaving `auth devices list/revoke`
unreachable from help entirely.

Add a strict-prefix match over the full-depth `collectCommands` walk so a
namespace path renders its subtree instead of erroring, mirroring gh's
`gh pr --help`. Introduce a shared `renderCommandRows` helper (aligned
`path  description` rows, grouped by first segment) that the top-level
overview will also reuse.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-03 19:06:02 +08:00
9727f7a94a refactor(cli/help): type byEffect param as CommandEffect
Tighten the safety-block helper to accept CommandEffect instead of a
bare string so effect literals are checked at compile time.
2026-06-03 17:51:48 +08:00
2d44910bf8 refactor(cli): drop hand-written flag lists from agentGuides
The ARGUMENTS & FLAGS / FILTERS / FLAGS blocks in the resume/get/run app
guides duplicated each command's static args/flags, which formatHelp
already renders authoritatively (and help -o json exposes as structured
arrays). The copies had begun to drift and listed only a subset. Remove
them so flag definitions are the single source; keep only the
non-derivable narrative. Migrate the two semantics that lived only in the
run app guide — --file remote-URL support and --conversation being
chat/advanced-chat only — into the flag descriptions themselves.
2026-06-03 17:51:48 +08:00
99174f90c8 docs(cli): surface the agent help topic in README
List `difyctl help agent` among the background docs and point the
"for agents" paragraph at it as the cross-command operating guide —
the topic added in the help-system work was reachable but unadvertised.
2026-06-03 17:51:48 +08:00
db3af89a4c feat(cli/skill): generate the agent skill at runtime via skill init
Pivot the agent skill from a checked-in, dev-time artifact to a pure
runtime render from the binary (HELP-SPEC D3.3/D3.4).

D3.3 — drop the generate-and-commit machinery:
- remove the committed skill/ bundle (SKILL.md + reference/*.md)
- delete scripts/generate-skill.ts and the skill:gen/skill:check scripts
  (and skill:check from the ci chain)
- drop the now-unused skill/** eslint ignore
- the existing renderSkill snapshot test is the determinism guard

D3.4 — add `difyctl skill init [dir]` as the single distribution path:
- renders from the local binary, so version-consistent by construction
- no dir -> ./.claude/skills/difyctl/; --user -> ~/.claude/skills/difyctl/;
  passing both is a usage error (exit 2); prints the path it wrote
- effect: write, surfaced in help -o json and the skill Safety section;
  safetyFraming broadened to "remote state or local files" to stay honest
2026-06-03 17:51:48 +08:00
a3f3a8d169 feat(cli/help): surface per-command output formats in machine-readable help
describeCommand dropped the output flag's allowed-value set, so
`difyctl help -o json` (the agent's discovery channel) omitted which
formats each leaf command actually supports — agents had to scrape the
flag description string, and the contract's union list invited
unsupported guesses like `version -o name`.

Promote the existing FlagDefinition.options onto FlagDescriptor so it
flows into all three help renderers. Generic over any enum flag, not an
output-format special case. Add a contract note that format support is
per-command.
2026-06-03 17:51:48 +08:00
1c0a3b3351 feat(cli): generate agent skill from the help single-source (phase 5)
Render the agent skill as a third view of the same data that feeds
`--help` text and `-o json`: SKILL.md (contract + login→get→describe→
run→resume golden path with HITL) plus reference/{account,external,
environment}.md, checked in and drift-gated by `skill:check`.

- add `effect: read|write|destructive` to the command model; surface it
  in `--help -o json` / `help -o json`; backfill write/destructive cmds
- src/help/skill.ts: pure renderSkill(tree); shell (trigger, opening,
  safety framing, golden-path order) is the only hand-authored part
- scripts/generate-skill.ts: bun generator isomorphic with tree:gen
- wire skill:gen / skill:check into ci; eslint-ignore generated skill/

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-03 17:51:48 +08:00
50b67efacb docs(cli): remove unified help system spec
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-03 17:51:47 +08:00
919f56ebee feat(cli): add agent topic and backfill command agentGuides
Add an `agent` help topic that consolidates the cross-command contract
(output, discovery, auth, exit codes, error envelope, HITL, retry),
deriving the volatile parts from CONTRACT so it cannot drift. Backfill
agentGuide on the commands an agent chains — get/describe/run/resume app
and auth login — following the run/app guide.ts convention. A wiring
guard test asserts each exposes a non-empty guide.

Per D1=A, agentGuide stays a single string: it appends to human --help
and surfaces in `--help -o json` / `help -o json` as-is.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-03 17:51:47 +08:00
10d9a6bc49 fix(cli): drop trailing whitespace on group rows in top-level help
Group nodes without a description (e.g. `auth devices`) printed a
trailing double-space. Emit the bare name when there is no description.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-03 17:51:47 +08:00
06d886d4ba feat(cli): machine-readable help via -o json
Make the help renderer format-aware. `difyctl <cmd> --help -o json`
returns a structured command descriptor (description, args, flags,
examples, agentGuide), and `difyctl help -o json` emits the full
command tree plus a global contract block (exit codes derived from
ExitCode, output formats, the stderr error envelope, and the HITL
pause/resume protocol) — the site map an agent needs in one call.

Adds src/help/contract.ts and registry.collectCommands(); resolves the
help path from leading positional tokens so output flags no longer leak
into resolution. Also tidies flag labels (aliases comma-separated, then
`<type>` after a space) as a side effect of the help.ts rewrite.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-03 17:51:47 +08:00
b77a242b13 refactor(cli): make help a verb, move topics to registry
Concept guides (account / external / environment) were modeled as
commands under `src/commands/help/*`, but `run.ts` intercepts
`argv[0] === 'help'` as a verb, so `difyctl help account` never reached
them — it fell through to the top-level list. The guides were therefore
unreachable yet still advertised.

Move the guides into a data-driven `src/help/topics.ts` registry
(mirroring `ENV_REGISTRY`) and rewrite the help branch to resolve in
order: command -> topic -> suggestion -> overview. The top-level help
now prints a GUIDES section, and `help <unknown>` exits non-zero with
suggestions. `commands/help/*` is removed from the command tree.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-03 17:51:47 +08:00
9f01fbf770 docs(cli): add unified help system spec
Record the design for making `help` a real verb, migrating concept
guides to a topic registry, and exposing machine-readable help for
agents via `-o json`. Locks D1=A (agentGuide stays a single string)
and D2=registry; defers phase 3 content backfill and D3 (skill gen).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-03 17:51:46 +08:00
yyh
5c7f05bd10 fix(web): auth form state management (#37003) 2026-06-03 09:14:01 +00:00
02e1a60cde chore: add missing @override decorator to api/configs (#37006)
Co-authored-by: mac <mac@1234.local>
2026-06-03 09:11:50 +00:00
57b573d02b refactor(api): migrate tenant/user via DI for several endpoints (#37004) 2026-06-03 08:59:00 +00:00
yyh
9de40e8f21 chore: update Claude skill links (#36997) 2026-06-03 08:00:35 +00:00
cad0942f4d fix(api): enforce workspace membership + role checks in auth pipeline (#36931)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 07:31:47 +00:00
cb9b1b593e feat: add Milvus TLS env examples (#36980) 2026-06-03 07:16:18 +00:00
2a8bdc2373 fix: pydantic_core._pydantic_core.ValidationError: 2 validation errors for DatasetDetailResponse (#36753) 2026-06-03 07:10:55 +00:00
ee6a07d13c refactor: use explicit session in inner api user auth (#36995) 2026-06-03 07:06:38 +00:00
yyh
2d6c9300e3 fix(api): tighten agent v2 generated contracts (#36989)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 06:52:40 +00:00
d6b4c800c2 refactor(web): migrate account education notice storage (#36991)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 06:39:22 +00:00
yyh
1b37635f92 fix: configure server console api url (#36958)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 06:22:46 +00:00
86af36429d fix: create app from template modal has no backdrop (#36987)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 06:14:46 +00:00
b96ea94505 chore: add :str to <path: parameter (#36913)
Co-authored-by: 99 <wh2099@pm.me>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 05:25:11 +00:00
d649cccda0 chore: add missing @override decorato to api/extensions (#36941)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-06-03 05:25:08 +00:00
5cbbd78f38 refactor(web): migrate chat sidebar collapse storage (#36963)
Co-authored-by: lmlm <7487674+popsiclelmlm@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 04:40:48 +00:00
5a0ad4ecd9 fix: normalize json_schema from string to dict in VariableEntity (#36777) 2026-06-03 04:33:25 +00:00
1e76b9e1b8 refactor(web): migrate workflow-node-panel-width to useSetLocalStorage (#36983)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 04:32:41 +00:00
1b972c4e09 refactor(api): migrate tenant/user via DI for several endpoints (#36971)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 04:24:17 +00:00
7968d2c3c8 refactor(web): migrate workflow-variable-inpsect-panel-height to useSetLocalStorage (#36982)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 03:48:59 +00:00
7507e9ba67 refactor(web): migrate debug-and-preview-panel-width to useSetLocalStorage (#36977)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 03:27:15 +00:00
y
ca31762e26 refactor(web): migrate education verifying storage to useLocalStorage (#36934)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 02:16:59 +00:00
f591da7865 ci: ruff cover agent (#36949) 2026-06-02 11:40:19 +00:00
f19679b217 refactor: improve network error and allow verbose output (#36923) 2026-06-02 10:43:40 +00:00
b682591c7a refactor(web): migrate question classifier label hint storage (#36932)
Co-authored-by: lmlm <7487674+popsiclelmlm@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-02 10:28:50 +00:00
8f6b59feff refactor(web): migrate rag recommendations collapsed storage (#36940)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-02 09:08:51 +00:00
99833f65d8 refactor(web): migrate NEED_REFRESH_APP_LIST_KEY to useLocalStorage/useSetLocalStorage (#36908)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: yyh <yuanyouhuilyz@gmail.com>
2026-06-02 08:41:01 +00:00
yyh
696fc5c213 refactor(web): manage goto anything open state with atom (#36938) 2026-06-02 08:23:18 +00:00
eae44cfecb feat(dify-agent): add shell layer (#36838) 2026-06-02 07:54:52 +00:00
yyh
dea4e66456 fix(web): use generated account-profile contracts (#36927)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-02 07:28:05 +00:00
3cd0da303a refactor: remove unused Flask-RESTX field dicts from end_user and conversation_variable fields (#28015) (#36929) 2026-06-02 07:27:23 +00:00
888483a2f8 fix: user token (#36930) 2026-06-02 07:20:07 +00:00
7056985f72 refactor: inject current user id in stop message endpoints (#36925)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-02 06:48:10 +00:00
6ce61eae59 fix(cli): invalidate app metadata cache on 422 to clear stale data (#36921) 2026-06-02 05:20:33 +00:00
yyh
079af312c6 fix(contracts): include account avatar url in profile schema (#36924)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-02 04:30:47 +00:00
0da13dfe4d refactor(cli): unify token storage behind Store + add host/account switching (#36830) 2026-06-02 04:05:53 +00:00
1ff4d75084 refactor(web): migrate anthropic quota notice storage (#36922)
Co-authored-by: lmlm <7487674+popsiclelmlm@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
2026-06-02 04:05:15 +00:00
e35d23c3cb feat(api): Agent App type S1 — AppMode.AGENT + create flow + binding (#36829)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-02 03:50:10 +00:00
e530e84772 refactor(web): migrate NOTE_SHOW_AUTHOR_STORAGE_KEY to useLocalStorage/useSetLocalStorage (#36915)
Signed-off-by: Cocoon-Break <54054995+kuishou68@users.noreply.github.com>
Co-authored-by: lingxiu58 <86288566+lingxiu58@users.noreply.github.com>
Co-authored-by: pojian68 <232320289+pojian68@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
2026-06-02 03:44:47 +00:00
2257a4f1ef refactor(web): migrate workflow featured collapsed storage (#36918)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-02 03:40:59 +00:00
yyh
f465dc5090 fix(web): defer react-scan loader (#36920) 2026-06-02 03:34:55 +00:00
5c1cfe6ada chore: ignore .vinext (#36914) 2026-06-02 02:43:15 +00:00
8d401d84c7 chore(api): adjust migration timestamp metadata for a1b2c3d4e5f6 (#36910) 2026-06-02 02:22:47 +00:00
b74287c2ab chore: update deps (#36911) 2026-06-02 02:10:59 +00:00
c64d3e98c4 fix(tools): use short-lived sessions for icon lookups to prevent idle-in-transaction (#36903)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-02 01:59:10 +00:00
yyh
a3265f722e docs: add client state guidelines (#36900) 2026-06-01 11:44:50 +00:00
5658065b97 test: satisfy strict pyrefly for migrated container tests (#36791) 2026-06-01 11:09:40 +00:00
yyh
8fc2807194 feat(web): create system-features vertical (#36894) 2026-06-01 10:15:25 +00:00
fc7716704d chore: not request system-features for cloud edition (#36891)
Co-authored-by: yyh <yuanyouhuilyz@gmail.com>
2026-06-01 09:31:16 +00:00
71ffaacb58 fix(api): centralize remote file retrieval (#36399)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-01 09:25:08 +00:00
cfc1cf2b8c refactor(cli/http): replace ky with a self-contained HTTP client (#36711)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-01 09:04:42 +00:00
yyh
055d9b9f0a refactor(web): migrate local storage hook usage (#36890) 2026-06-01 08:20:13 +00:00
yyh
21711bebeb refactor(web): migrate local storage access to react hook (#36888) 2026-06-01 07:57:54 +00:00
yyh
becccbf288 fix(web): read pnpm config env in standalone start (#36887) 2026-06-01 07:18:50 +00:00
86497045c9 feat: per-credential visibility control for plugin credentials (#35468)
Co-authored-by: Yang <yang@Yangs-MacBook-Pro.local>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-01 05:56:18 +00:00
687a177b24 chore: add override decorators to core repositories (#36885) 2026-06-01 05:24:21 +00:00
4a6d278354 refactor(web): mark workflow run props readonly (#36857) 2026-06-01 05:06:21 +00:00
yyh
7d69302e9f chore: update deps (#36884) 2026-06-01 04:28:04 +00:00
yyh
bcd573e560 fix(web): respect marketplace feature flag in model selector (#36883) 2026-06-01 04:11:58 +00:00
yyh
07c0c4e7b1 chore(web): remove TanStack devtools (#36882) 2026-06-01 03:57:50 +00:00
yyh
a8a2ca7b98 chore(cli): move eslint config into cli package (#36878) 2026-06-01 03:54:14 +00:00
de47d43b65 refactor: convert isinstance chains to match/case syntax (#36862)
Co-authored-by: krishkantiuj-ren <hiccup.cc.3@gmail.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-06-01 03:45:19 +00:00
240912cef5 fix(api): preserve hierarchical estimate rules (#36852)
Co-authored-by: root <kinsonnee@gmail.com>
2026-06-01 03:16:09 +00:00
72e040ead3 docs: add security policy (#36873) 2026-06-01 09:58:32 +08:00
c0ee821d45 refactor: use absolute path for inter dir importing (#36822)
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-06-01 01:32:16 +00:00
c7c3296572 fix: MCP search results include only MCP providers (#36871)
Co-authored-by: LL201314-II <you@example.com>
2026-06-01 01:13:51 +00:00
e7be04fd58 fix(api): dedup EndUser in plugin get_user by session_id for Reverse Invocation (#36742) 2026-06-01 00:57:29 +00:00
df6b5be50a refactor: convert isinstance chains to match/case (part 5) (#36503) 2026-05-31 15:08:59 +00:00
8e5f09091b refactor: convert if isinstance chains to match case (#36846)
Co-authored-by: duongynhi000005-oss <duongynhi000005-oss@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-05-31 15:05:43 +00:00
0a3005701f refactor: inject current user into user-only controllers (#36754)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-05-31 15:03:15 +00:00
d8571ce965 refactor: convert isinstance chains to match/case (part 4) (#36274)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-05-31 14:44:17 +00:00
f241ae25be fix: #36585 dep inject current user id (#36845)
Co-authored-by: duongynhi000005-oss <duongynhi000005-oss@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-31 14:37:39 +00:00
c6474a2a8b refactor: convert isinstance chains to match/case (part 8) (#36869) 2026-05-31 14:11:05 +00:00
yyh
480d05bc48 fix(web): prefetch workspace and guard routes with contract query (#36870) 2026-05-31 14:02:00 +00:00
yyh
f75725ccd9 feat(web): add server oRPC client (#36856) 2026-05-31 13:14:28 +00:00
yyh
2fe8c48255 refactor(web): scope workflow hotkeys (#36860)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-31 13:14:13 +00:00
ec5404cc9d chore: split trial models to a single API (#36796)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-31 13:09:13 +00:00
yyh
20f62b9919 fix(web): use generated current workspace query (#36843)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-31 13:04:18 +00:00
04f5555580 chore: split to single app_dsl_version API (#36864)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-31 12:13:44 +00:00
129af96c23 chore: add missing @override decorators to pipeline WorkflowAppGenerateResponseConverter (#36859)
Co-authored-by: krishkantiuj-ren <hiccup.cc.3@gmail.com>
2026-05-31 12:02:17 +00:00
df40960f5d chore: dep inject for model (#36750)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
2026-05-30 17:40:46 +00:00
599960024d refactor(api): migrate console/service_api.dataset.document to BaseModel (#36506)
Co-authored-by: WH-2099 <wh2099@pm.me>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-30 14:38:27 +00:00
yyh
6805d9bfc0 fix(auth): reset profile query after login (#36851) 2026-05-30 14:34:04 +00:00
928f888ef5 refactor(api): migrate console/service_api.dataset.segment to BaseModel (#36522)
Co-authored-by: WH-2099 <wh2099@pm.me>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-30 13:54:01 +00:00
yyh
f46c03460e fix(auth): avoid leaking request origin in refresh redirects (#36847) 2026-05-30 05:55:18 +00:00
0b60338ad5 chore: reuse injected SQLAlchemy sessions in app read paths (#36798) 2026-05-30 00:23:58 +00:00
yyh
91ac465982 fix(web): use default profile query cache (#36832) 2026-05-29 14:18:39 +00:00
yyh
9490d63c50 refactor(web): remove app initializer and move auth boot logic to route boundaries (#36818) 2026-05-29 12:26:34 +00:00
ae538ced47 chore: using single SSH_SCRIPT for saas dev (#36827) 2026-05-29 10:07:15 +00:00
487249728b fix: remove unnecessary # type: ignore comments (#24494) (#36825) 2026-05-29 09:41:32 +00:00
372a2e3e9c refactor: convert isinstance chains to match/case (part 7) (#35902) (#36826) 2026-05-29 09:40:33 +00:00
4939a9c33d refactor: add ts common style check for web and cli (#36823) 2026-05-29 09:26:32 +00:00
b6f92f1dc4 fix(cli): fix style (#36821) 2026-05-29 08:34:36 +00:00
ce276573a8 chore: deploy saas dev workflow (#36819) 2026-05-29 08:30:55 +00:00
5070cc9668 refactor(cli): optimize error handling in flag parsing (#36810) 2026-05-29 07:39:26 +00:00
a392a72960 chore: not store search tag condition in url (#36814) 2026-05-29 07:30:35 +00:00
30270b5c30 fix(device): surface SSO errors on /device and fix CLI null-account crash on external-SSO login (#36781) 2026-05-29 06:51:34 +00:00
24715a9570 chore: unified plugin status icon position (#36816) 2026-05-29 06:45:25 +00:00
c530a5d272 fix(api): validate annotation list pagination query (#36807)
Co-authored-by: root <kinsonnee@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-29 06:25:48 +00:00
418ee7398e fix: install failed plugin dose not show icon (#36811) 2026-05-29 06:07:43 +00:00
78f40c0d25 test: stabilize modal context pricing test (#36524) 2026-05-29 05:19:37 +00:00
2cc567c6a3 feat: add DTO for agent api (#36797)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-29 03:36:41 +00:00
a180ab19e4 chore: type check test container tests (#36790) 2026-05-29 01:54:25 +00:00
13eaa436e7 test: isolate Redis state in container tests (#36740) 2026-05-28 12:42:25 +00:00
3596d12e4c refactor(cli): use Store interface as token storage (#36726) 2026-05-28 10:02:51 +00:00
e8de10a3b5 feat(docker): add missing OPENAPI_* env vars to shared.env.example (#36752) 2026-05-28 08:52:03 +00:00
f5ab5e7eb3 fix: fix cannot extract elements from a scalar (#36769) 2026-05-28 07:31:36 +00:00
0c40e1c2a0 feat: add cross-environment app migration workflow (#36765)
Co-authored-by: XW <wei.xu1@wiz.ai>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-28 07:30:33 +00:00
c29d76757e docs(api): fix typo in vector migration docstrings (#36741) 2026-05-28 07:15:34 +00:00
91c1d3ad81 fix: handle null plugin badges (#36767) 2026-05-28 07:00:32 +00:00
57b02e341c refactor: add @override decorators to storage backend subclasses (#36406) (#36755) 2026-05-28 06:04:47 +00:00
b94ff65e9f fix(docker): copy dify-agent source into production stage (#36757) 2026-05-28 06:01:11 +00:00
678260e34e test: migrate workspace members tests to containers (#36738)
Co-authored-by: jamesrayammons <63717587+jamesrayammons@users.noreply.github.com>
2026-05-28 06:01:05 +00:00
739e34d08a fix(docker): pin web docker node version (#36756) 2026-05-28 05:25:41 +00:00
825fb9cb89 chore(codeowners): add Riskey for service API docs (#36731) 2026-05-28 05:06:12 +00:00
982 changed files with 39341 additions and 14696 deletions

View File

@ -1 +0,0 @@
../../.agents/skills/frontend-query-mutation

View File

@ -0,0 +1 @@
../../.agents/skills/how-to-write-component

1
.github/CODEOWNERS vendored
View File

@ -166,6 +166,7 @@
# Frontend - App - API Documentation
/web/app/components/develop/ @JzoNgKVO @iamjoel
/web/app/components/develop/template/*.mdx @JzoNgKVO @iamjoel @RiskeyL
# Frontend - App - Logs and Annotations
/web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel

View File

@ -51,6 +51,15 @@ jobs:
with:
files: |
api/**
- name: Check dify-agent inputs
if: github.event_name != 'merge_group'
id: dify-agent-changes
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
dify-agent/**/*.py
dify-agent/pyproject.toml
dify-agent/uv.lock
- if: github.event_name != 'merge_group'
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
with:
@ -76,6 +85,17 @@ jobs:
# Format code
uv run ruff format ..
- if: github.event_name != 'merge_group' && steps.dify-agent-changes.outputs.any_changed == 'true'
run: |
cd dify-agent
uv sync --dev
# fmt first to avoid line too long
uv run ruff format .
# Fix lint errors
uv run ruff check --fix .
# Format code
uv run ruff format .
- name: count migration progress
if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true'
run: |

View File

@ -1,4 +1,4 @@
name: Deploy Agent Dev
name: Deploy SaaS
permissions:
contents: read
@ -7,7 +7,7 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "deploy/agent-dev"
- "deploy/saas"
types:
- completed
@ -16,13 +16,13 @@ jobs:
runs-on: depot-ubuntu-24.04
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/agent-dev'
github.event.workflow_run.head_branch == 'deploy/saas'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
with:
host: ${{ secrets.AGENT_DEV_SSH_HOST }}
host: ${{ secrets.SAAS_DEV_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |
${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}
${{ vars.SSH_SCRIPT_SAAS_DEV || secrets.SSH_SCRIPT_SAAS_DEV }}

View File

@ -95,6 +95,51 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
uses: ./.github/actions/setup-web
- name: Web tsslint
if: steps.changed-files.outputs.any_changed == 'true'
env:
NODE_OPTIONS: --max-old-space-size=4096
run: vp run lint:tss
- name: Web dead code check
if: steps.changed-files.outputs.any_changed == 'true'
run: vp run knip
ts-common-style:
name: TS Common
runs-on: depot-ubuntu-24.04
permissions:
checks: write
pull-requests: read
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
web/**
cli/**
e2e/**
sdks/nodejs-client/**
packages/**
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.nvmrc
eslint.config.mjs
.github/workflows/style.yml
.github/actions/setup-web/**
- name: Setup web environment
if: steps.changed-files.outputs.any_changed == 'true'
uses: ./.github/actions/setup-web
- name: Restore ESLint cache
if: steps.changed-files.outputs.any_changed == 'true'
id: eslint-cache-restore
@ -105,28 +150,14 @@ jobs:
restore-keys: |
${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
- name: Web style check
- name: Style check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: .
run: vp run lint:ci
- name: Web tsslint
- name: Type check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
env:
NODE_OPTIONS: --max-old-space-size=4096
run: vp run lint:tss
- name: Web type check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: .
run: vp run type-check
- name: Web dead code check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: vp run knip
- name: Save ESLint cache
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5

3
.gitignore vendored
View File

@ -259,3 +259,6 @@ scripts/stress-test/reports/
.qoder/*
.context/
.eslintcache
# Vitest local reports
web/.vitest-reports/

27
SECURITY.md Normal file
View File

@ -0,0 +1,27 @@
# Security Policy
## Reporting a Vulnerability
If you believe you have found a security vulnerability in Dify, please report it privately through GitHub Security Advisories:
https://github.com/langgenius/dify/security/advisories/new
Please do not report security vulnerabilities through public GitHub issues, discussions, or pull requests.
When submitting a report, include as much relevant information as you can safely provide, such as:
- A description of the vulnerability
- Steps to reproduce, if safe to share privately
- Affected components, versions, or configurations
- Potential impact
- Any suggested mitigation or fix, if available
The maintainers will review reports submitted through GitHub Security Advisories and coordinate follow-up there.
## Public Disclosure
Please avoid publicly disclosing details of a vulnerability until it has been reviewed and, where appropriate, a fix or mitigation has been made available.
## Security Updates
Security fixes may be released through normal project releases or other appropriate channels. Users are encouraged to keep Dify deployments up to date.

View File

@ -27,7 +27,7 @@ COPY api/providers ./providers
COPY dify-agent/pyproject.toml dify-agent/README.md /app/dify-agent/
COPY dify-agent/src /app/dify-agent/src
# Trust the checked-in lock during image builds; local path sources are copied from the repository context.
RUN uv sync --frozen --no-dev
RUN uv sync --frozen --no-dev --no-editable
# production stage
FROM base AS production

View File

@ -34,6 +34,7 @@ from clients.agent_backend.request_builder import (
DIFY_PLUGIN_TOOLS_LAYER_ID,
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
WORKFLOW_USER_PROMPT_LAYER_ID,
AgentBackendAgentAppRunInput,
AgentBackendModelConfig,
AgentBackendOutputConfig,
AgentBackendRunRequestBuilder,
@ -49,6 +50,7 @@ __all__ = [
"DIFY_PLUGIN_TOOLS_LAYER_ID",
"WORKFLOW_NODE_JOB_PROMPT_LAYER_ID",
"WORKFLOW_USER_PROMPT_LAYER_ID",
"AgentBackendAgentAppRunInput",
"AgentBackendError",
"AgentBackendHTTPError",
"AgentBackendInternalEvent",

View File

@ -45,6 +45,7 @@ from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator
AGENT_SOUL_PROMPT_LAYER_ID = "agent_soul_prompt"
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID = "workflow_node_job_prompt"
WORKFLOW_USER_PROMPT_LAYER_ID = "workflow_user_prompt"
AGENT_APP_USER_PROMPT_LAYER_ID = "agent_app_user_prompt"
DIFY_EXECUTION_CONTEXT_LAYER_ID = "execution_context"
DIFY_PLUGIN_TOOLS_LAYER_ID = "tools"
@ -181,9 +182,138 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
return value
class AgentBackendAgentAppRunInput(BaseModel):
"""Inputs to build one Agent App conversation-turn run request.
Unlike the workflow-node input there is no workflow-node-job prompt and no
previous-node context: the user prompt is the chat message, and multi-turn
continuity comes from ``session_snapshot`` + the history layer keyed by the
conversation.
"""
model: AgentBackendModelConfig
execution_context: DifyExecutionContextLayerConfig
user_prompt: str
agent_soul_prompt: str | None = None
purpose: RunPurpose = "agent_app"
idempotency_key: str | None = None
output: AgentBackendOutputConfig | None = None
tools: DifyPluginToolsLayerConfig | None = None
session_snapshot: CompositorSessionSnapshot | None = None
include_history: bool = True
suspend_on_exit: bool = True
metadata: dict[str, JsonValue] = Field(default_factory=dict)
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
@field_validator("user_prompt")
@classmethod
def _reject_blank_prompt(cls, value: str) -> str:
if not value.strip():
raise ValueError("prompt must not be blank")
return value
class AgentBackendRunRequestBuilder:
"""Converts API product state into the public ``dify-agent`` run protocol."""
def build_for_agent_app(self, run_input: AgentBackendAgentAppRunInput) -> CreateRunRequest:
"""Build an Agent App conversation-turn run request.
Layer graph: optional Agent Soul system prompt → user prompt →
execution context → optional history (multi-turn) → LLM → optional
plugin tools → optional structured output. Mirrors the workflow-node
layer ordering minus the workflow-job / previous-node prompt.
"""
layers: list[RunLayerSpec] = []
if run_input.agent_soul_prompt:
layers.append(
RunLayerSpec(
name=AGENT_SOUL_PROMPT_LAYER_ID,
type=PLAIN_PROMPT_LAYER_TYPE_ID,
metadata={**run_input.metadata, "origin": "agent_soul"},
config=PromptLayerConfig(prefix=run_input.agent_soul_prompt),
)
)
layers.extend(
[
RunLayerSpec(
name=AGENT_APP_USER_PROMPT_LAYER_ID,
type=PLAIN_PROMPT_LAYER_TYPE_ID,
metadata={**run_input.metadata, "origin": "agent_app_user_prompt"},
config=PromptLayerConfig(user=run_input.user_prompt),
),
RunLayerSpec(
name=DIFY_EXECUTION_CONTEXT_LAYER_ID,
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
metadata=run_input.metadata,
config=run_input.execution_context,
),
]
)
if run_input.include_history:
layers.append(
RunLayerSpec(
name=DIFY_AGENT_HISTORY_LAYER_ID,
type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID,
metadata={**run_input.metadata, "origin": "agent_session_history"},
)
)
layers.append(
RunLayerSpec(
name=DIFY_AGENT_MODEL_LAYER_ID,
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
metadata=run_input.metadata,
config=DifyPluginLLMLayerConfig(
plugin_id=run_input.model.plugin_id,
model_provider=run_input.model.model_provider,
model=run_input.model.model,
credentials=run_input.model.credentials,
model_settings=run_input.model.model_settings or None,
),
)
)
if run_input.tools is not None and run_input.tools.tools:
layers.append(
RunLayerSpec(
name=DIFY_PLUGIN_TOOLS_LAYER_ID,
type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
metadata=run_input.metadata,
config=run_input.tools,
)
)
if run_input.output is not None:
layers.append(
RunLayerSpec(
name=DIFY_AGENT_OUTPUT_LAYER_ID,
type=DIFY_OUTPUT_LAYER_TYPE_ID,
metadata=run_input.metadata,
config=DifyOutputLayerConfig(
json_schema=run_input.output.json_schema,
description=run_input.output.description,
strict=run_input.output.strict,
),
)
)
return CreateRunRequest(
composition=RunComposition(layers=layers),
purpose=run_input.purpose,
idempotency_key=run_input.idempotency_key,
metadata=run_input.metadata,
session_snapshot=run_input.session_snapshot,
on_exit=LayerExitSignals(
default=ExitIntent.SUSPEND if run_input.suspend_on_exit else ExitIntent.DELETE,
),
)
def build_cleanup_request(
self,
*,

View File

@ -4,6 +4,12 @@ CLI command modules extracted from `commands.py`.
from .account import create_tenant, reset_email, reset_password
from .data_migrate import data_migrate, legacy_model_types
from .data_migration import (
export_migration_data,
export_migration_data_template,
import_migration_data,
migration_data_wizard,
)
from .plugin import (
extract_plugins,
extract_unique_plugins,
@ -26,7 +32,12 @@ from .retention import (
restore_workflow_runs,
)
from .storage import clear_orphaned_file_records, file_usage, migrate_oss, remove_orphaned_files_on_storage
from .system import convert_to_agent_apps, fix_app_site_missing, reset_encrypt_key_pair, upgrade_db
from .system import (
convert_to_agent_apps,
fix_app_site_missing,
reset_encrypt_key_pair,
upgrade_db,
)
from .vector import (
add_qdrant_index,
migrate_annotation_vector_database,
@ -48,10 +59,13 @@ __all__ = [
"data_migrate",
"delete_archived_workflow_runs",
"export_app_messages",
"export_migration_data",
"export_migration_data_template",
"extract_plugins",
"extract_unique_plugins",
"file_usage",
"fix_app_site_missing",
"import_migration_data",
"install_plugins",
"install_rag_pipeline_plugins",
"legacy_model_types",
@ -59,6 +73,7 @@ __all__ = [
"migrate_data_for_plugin",
"migrate_knowledge_vector_database",
"migrate_oss",
"migration_data_wizard",
"old_metadata_migration",
"remove_orphaned_files_on_storage",
"reset_email",

View File

@ -0,0 +1,754 @@
from __future__ import annotations
import json
from datetime import datetime
from pathlib import Path
from typing import Any, cast
from uuid import UUID
import click
import sqlalchemy as sa
import yaml
from extensions.ext_database import db
from models import Tenant
from models.model import App
from models.tools import ApiToolProvider, MCPToolProvider, WorkflowToolProvider
from services.app_dsl_service import AppDslService
from services.data_migration.dependency_discovery_service import DependencyDiscoveryService
from services.data_migration.entities import (
DependencyKind,
ImportOptions,
MigrationDataError,
ReportContext,
ResourceReportItem,
)
from services.data_migration.export_service import ExportConfigParser, MigrationExportService
from services.data_migration.import_service import ImportRequest, MigrationImportService
from services.data_migration.package_service import MigrationPackageService
from services.data_migration.report_service import MigrationReportService
ID_STRATEGY_CHOICES = ["preserve-id", "generate-new-id"]
CONFLICT_STRATEGY_CHOICES = ["fail", "skip", "update"]
SUPPORTED_WIZARD_APP_MODES = ["workflow", "advanced-chat"]
WizardToolMap = dict[str, dict[str, str | None]]
WizardToolSelection = dict[str, list[str]]
def _scripted_export_template() -> dict[str, Any]:
return {
"source_tenant": {
"mode": "single",
"id": "",
"name": "admin's Workspace",
},
"apps": {
"modes": ["workflow", "advanced-chat"],
"ids": [],
"all": True,
},
"include_referenced_tools": True,
"additional_tools": {
"api_tools": [],
"workflow_tools": [],
"mcp_tools": [],
},
"include_secrets": False,
"import_options": {
"create_app_api_token_on_import": False,
"id_strategy": "preserve-id",
"conflict_strategy": "fail",
},
}
@click.command("app-migration-template", help="Print or write a scripted export config JSON template.")
@click.option(
"--output",
"output_file",
required=False,
type=click.Path(dir_okay=False),
help="Path to write the export config JSON template. Prints to stdout when omitted.",
)
@click.option("--overwrite", is_flag=True, default=False, help="Overwrite output if it already exists.")
def export_migration_data_template(output_file: str | None, overwrite: bool) -> None:
template_json = json.dumps(_scripted_export_template(), indent=2, ensure_ascii=False) + "\n"
if output_file is None:
click.echo(template_json, nl=False)
return
path = Path(output_file)
if path.exists() and not overwrite:
raise click.ClickException(f"Output file already exists: {output_file}")
path.write_text(template_json)
click.echo(click.style(f"Output written to {output_file}", fg="green"))
@click.command("export-app-migration", help="Export workflow migration data to a versioned JSON package.")
@click.option(
"--input",
"input_file",
required=False,
type=click.Path(exists=True, dir_okay=False),
help="Path to export config JSON.",
)
@click.option(
"--output",
"output_file",
required=False,
type=click.Path(dir_okay=False),
help="Path to migration package JSON.",
)
@click.option("--overwrite", is_flag=True, default=False, help="Overwrite output if it already exists.")
def export_migration_data(input_file: str | None, output_file: str | None, overwrite: bool) -> None:
try:
_require_options(("--input", input_file), ("--output", output_file))
assert input_file is not None
assert output_file is not None
raw_config = _load_json_object(input_file, "Export config")
selection = ExportConfigParser().parse(raw_config)
result = MigrationExportService().export(selection)
MigrationPackageService().save_package(result.package, output_file, overwrite=overwrite)
click.echo(click.style(f"Output written to {output_file}", fg="green"))
_render_report(result.report_items, context=_with_output_path(result.report_context, output_file))
except MigrationDataError as exc:
raise click.ClickException(str(exc)) from exc
@click.command("import-app-migration", help="Import a versioned migration data package.")
@click.option(
"--input",
"input_file",
required=False,
type=click.Path(exists=True, dir_okay=False),
help="Path to migration package JSON.",
)
@click.option("--target-tenant", default=None, help="Target tenant/workspace name. Overrides package metadata.")
@click.option("--operator-email", default=None, help="Operator account email in the target tenant.")
@click.option(
"--id-strategy",
default=None,
type=click.Choice(ID_STRATEGY_CHOICES),
help="Override package ID strategy.",
)
@click.option(
"--conflict-strategy",
default=None,
type=click.Choice(CONFLICT_STRATEGY_CHOICES),
help="Override package conflict strategy.",
)
@click.option(
"--create-app-api-token-on-import/--no-create-app-api-token-on-import",
default=None,
help="Override package app API token creation behavior.",
)
def import_migration_data(
input_file: str | None,
target_tenant: str | None,
operator_email: str | None,
id_strategy: str | None,
conflict_strategy: str | None,
create_app_api_token_on_import: bool | None,
) -> None:
try:
_require_options(("--input", input_file))
assert input_file is not None
package = MigrationPackageService().load_package(input_file)
result = MigrationImportService().import_package(
ImportRequest(
package=package,
cli_target_tenant=target_tenant,
operator_email=operator_email,
options_override=_build_options_override(
package.metadata.import_options,
id_strategy=id_strategy,
conflict_strategy=conflict_strategy,
create_app_api_token_on_import=create_app_api_token_on_import,
),
)
)
_render_report(result.report_items, context=result.report_context)
except MigrationDataError as exc:
raise click.ClickException(str(exc)) from exc
def parse_index_selection(raw: str, values: list[str]) -> list[str]:
normalized = raw.strip().lower()
if normalized == "all":
return values
selected: list[str] = []
for part in raw.split(","):
stripped = part.strip()
if not stripped:
continue
try:
index = int(stripped)
except ValueError as exc:
raise click.ClickException(f"Selection must be 'all' or comma-separated numbers: {raw}") from exc
if index < 1 or index > len(values):
raise click.ClickException(f"Selection index out of range: {index}")
selected.append(values[index - 1])
return list(dict.fromkeys(selected))
def _print_wizard_step(title: str) -> None:
click.echo("")
click.echo(f"==== {title} ====")
def _print_wizard_substep(title: str) -> None:
click.echo("")
click.echo(f"-- {title} --")
@click.command("app-migration-wizard", help="Interactively export workflow migration data.")
def migration_data_wizard() -> None:
try:
tenant = _prompt_source_tenant()
apps = _eligible_apps_for_tenant(tenant.id)
app_ids = _prompt_app_ids(apps)
_print_wizard_step("Referenced Tools")
include_referenced_tools = click.confirm(
"Automatically export tools referenced by selected apps? [y/n, default: y]",
default=True,
show_default=False,
)
auto_tools = _discover_auto_tools([app for app in apps if app.id in set(app_ids)], include_referenced_tools)
auto_tools = _resolve_auto_tool_names(tenant.id, auto_tools)
_print_auto_tools(auto_tools)
additional_tools = _prompt_additional_tools(tenant.id, auto_tools)
include_secrets, create_tokens, id_strategy, conflict_strategy = _prompt_import_options()
_print_wizard_step("Output")
output_file, overwrite = _prompt_output_file()
selection = ExportConfigParser().parse(
{
"source_tenant": {"mode": "single", "id": tenant.id, "name": tenant.name},
"apps": {"ids": app_ids, "all": False},
"include_referenced_tools": include_referenced_tools,
"additional_tools": additional_tools,
"include_secrets": include_secrets,
"import_options": {
"create_app_api_token_on_import": create_tokens,
"id_strategy": id_strategy,
"conflict_strategy": conflict_strategy,
},
}
)
_confirm_wizard_summary(
tenant_name=tenant.name,
app_names=[app.name for app in apps if app.id in set(app_ids)],
auto_tools=auto_tools,
additional_tools=additional_tools,
manual_labels=_selected_tool_labels_for_tenant(tenant.id, additional_tools),
include_referenced_tools=include_referenced_tools,
include_secrets=include_secrets,
create_tokens=create_tokens,
id_strategy=id_strategy,
conflict_strategy=conflict_strategy,
output_file=output_file,
)
result = MigrationExportService().export(selection)
MigrationPackageService().save_package(result.package, output_file, overwrite=overwrite)
click.echo(click.style(f"Output written to {output_file}", fg="green"))
_print_wizard_step("Report")
_render_report(result.report_items, context=_with_output_path(result.report_context, output_file))
except MigrationDataError as exc:
raise click.ClickException(str(exc)) from exc
def _load_json_object(path: str, label: str) -> dict[str, Any]:
try:
with Path(path).open(encoding="utf-8") as file:
raw = json.load(file)
except json.JSONDecodeError as exc:
raise MigrationDataError(f"{label} JSON is invalid: {exc.msg}") from exc
if not isinstance(raw, dict):
raise MigrationDataError(f"{label} JSON must be an object.")
return raw
def _require_options(*options: tuple[str, object | None]) -> None:
missing_options = [name for name, value in options if value is None]
if missing_options:
raise click.UsageError(f"Missing option(s): {', '.join(missing_options)}.")
def _build_options_override(
package_options: ImportOptions,
*,
id_strategy: str | None,
conflict_strategy: str | None,
create_app_api_token_on_import: bool | None,
) -> ImportOptions | None:
if id_strategy is None and conflict_strategy is None and create_app_api_token_on_import is None:
return None
return ImportOptions.from_mapping(
{
"id_strategy": id_strategy or package_options.id_strategy,
"conflict_strategy": conflict_strategy or package_options.conflict_strategy,
"create_app_api_token_on_import": (
create_app_api_token_on_import
if create_app_api_token_on_import is not None
else package_options.create_app_api_token_on_import
),
}
)
def _prompt_source_tenant() -> Tenant:
tenants = list(db.session.scalars(sa.select(Tenant).order_by(Tenant.name.asc())).all())
if not tenants:
raise MigrationDataError("No tenants found.")
_print_wizard_step("Source Tenant")
click.echo("Source tenants:")
for index, tenant in enumerate(tenants, 1):
click.echo(f"{index}. {tenant.name} ({tenant.id})")
tenant_index = click.prompt("Select one source tenant by number", type=int, default=1, show_default=True)
if tenant_index < 1 or tenant_index > len(tenants):
raise click.ClickException(f"Selection index out of range: {tenant_index}")
return tenants[tenant_index - 1]
def _eligible_apps_for_tenant(tenant_id: str) -> list[App]:
return list(
db.session.scalars(
sa.select(App)
.where(App.tenant_id == tenant_id, App.mode.in_(SUPPORTED_WIZARD_APP_MODES))
.order_by(App.name.asc())
).all()
)
def _prompt_app_ids(apps: list[App]) -> list[str]:
if not apps:
raise MigrationDataError("No workflow or advanced-chat apps found for the selected tenant.")
_print_wizard_step("App Selection")
click.echo("Currently supported app types: workflow and chatflow.")
click.echo("Workflow/chatflow apps:")
for index, app in enumerate(apps, 1):
mode = app.mode.value if hasattr(app.mode, "value") else app.mode
click.echo(f"{index}. {app.name} [{mode}] ({app.id})")
app_ids = parse_index_selection(
click.prompt("Select apps by number, comma-separated numbers, or all", default="all"),
[app.id for app in apps],
)
selected_apps = [app for app in apps if app.id in set(app_ids)]
click.echo("Selected apps:")
for app in selected_apps:
click.echo(f"- {app.name} ({app.id})")
return app_ids
def _prompt_import_options() -> tuple[bool, bool, str, str]:
_print_wizard_step("Import Options")
_print_wizard_substep("Secrets")
click.echo("Secrets include workflow/app DSL secret values, custom API tool credentials,")
click.echo("and full MCP provider connection data such as server URL, headers, authentication, and tool list.")
click.echo("If you choose no, credentials are omitted or masked,")
click.echo("and MCP providers are exported as dependency metadata only.")
click.echo("Treat the output JSON as sensitive if you choose yes.")
include_secrets = click.confirm(
"Include secrets in output JSON? [y/n, default: n]",
default=False,
show_default=False,
)
_print_wizard_substep("App API Tokens")
click.echo("When enabled, import will create an app API token if the imported app has none,")
click.echo("or reuse an existing app API token if one already exists.")
create_tokens = click.confirm(
"Create or reuse app API tokens during import? [y/n, default: n]",
default=False,
show_default=False,
)
_print_wizard_substep("ID Strategy")
click.echo("ID strategy controls whether imported app and tool IDs preserve source IDs")
click.echo("or use target-generated IDs.")
click.echo("preserve-id: keep source IDs where the target service supports it.")
click.echo("generate-new-id: let the target environment generate new IDs and rewrite references via mapping.")
id_strategy = click.prompt(
"Import ID strategy. Enter one of: preserve-id, generate-new-id",
type=click.Choice(ID_STRATEGY_CHOICES),
default="preserve-id",
show_default=True,
)
_print_wizard_substep("Conflict Strategy")
click.echo("Conflict strategy controls what import does when a target resource already exists.")
click.echo("fail: stop at the first conflict; previously committed resources are not rolled back.")
click.echo("skip: keep the existing target resource and skip importing that resource.")
click.echo("update: update the existing target resource in place.")
conflict_strategy = click.prompt(
"Import conflict strategy. Enter one of: fail, skip, update",
type=click.Choice(CONFLICT_STRATEGY_CHOICES),
default="update",
show_default=True,
)
return include_secrets, create_tokens, id_strategy, conflict_strategy
def _discover_auto_tools(apps: list[App], include_referenced_tools: bool) -> WizardToolMap:
auto_tools: WizardToolMap = {"api_tools": {}, "workflow_tools": {}, "mcp_tools": {}}
if not include_referenced_tools:
return auto_tools
discovery_service = DependencyDiscoveryService()
for app in apps:
dsl_content = AppDslService.export_dsl(app_model=app, include_secret=False)
raw_dsl = yaml.safe_load(dsl_content) if dsl_content else {}
dsl = raw_dsl if isinstance(raw_dsl, dict) else {}
for dependency in discovery_service.discover_from_dsl(dsl):
if dependency.kind == DependencyKind.API_TOOL:
auto_tools["api_tools"][dependency.provider_name or dependency.provider_id] = dependency.provider_id
elif dependency.kind == DependencyKind.WORKFLOW_TOOL:
auto_tools["workflow_tools"][dependency.provider_name or dependency.provider_id] = (
dependency.provider_id
)
elif dependency.kind == DependencyKind.MCP_TOOL:
auto_tools["mcp_tools"][dependency.provider_name or dependency.provider_id] = dependency.provider_id
return auto_tools
def _resolve_auto_tool_names(tenant_id: str, auto_tools: WizardToolMap) -> WizardToolMap:
return {
"api_tools": _resolve_api_tool_names(tenant_id, auto_tools["api_tools"]),
"workflow_tools": _resolve_workflow_tool_names(tenant_id, auto_tools["workflow_tools"]),
"mcp_tools": _resolve_mcp_tool_names(tenant_id, auto_tools["mcp_tools"]),
}
def _resolve_api_tool_names(tenant_id: str, tools: dict[str, str | None]) -> dict[str, str | None]:
resolved: dict[str, str | None] = {}
for name, identifier in tools.items():
predicates = [ApiToolProvider.name == name]
if _is_uuid_string(identifier):
predicates.append(ApiToolProvider.id == identifier)
provider = db.session.scalar(
sa.select(ApiToolProvider).where(
ApiToolProvider.tenant_id == tenant_id,
sa.or_(*predicates),
)
)
resolved[provider.name if provider else name] = provider.id if provider else identifier
return resolved
def _resolve_workflow_tool_names(tenant_id: str, tools: dict[str, str | None]) -> dict[str, str | None]:
resolved: dict[str, str | None] = {}
for name, identifier in tools.items():
predicates = [WorkflowToolProvider.name == name]
if _is_uuid_string(identifier):
predicates.append(WorkflowToolProvider.id == identifier)
provider = db.session.scalar(
sa.select(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id,
sa.or_(*predicates),
)
)
resolved[provider.name if provider else name] = provider.id if provider else identifier
return resolved
def _resolve_mcp_tool_names(tenant_id: str, tools: dict[str, str | None]) -> dict[str, str | None]:
resolved: dict[str, str | None] = {}
for name, identifier in tools.items():
predicates = [MCPToolProvider.name == name]
if identifier:
predicates.append(MCPToolProvider.server_identifier == identifier)
if _is_uuid_string(identifier):
predicates.append(MCPToolProvider.id == identifier)
provider = db.session.scalar(
sa.select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id,
sa.or_(*predicates),
)
)
resolved[provider.name if provider else name] = provider.id if provider else identifier
return resolved
def _is_uuid_string(value: str | None) -> bool:
if not value:
return False
try:
UUID(value)
except ValueError:
return False
return True
def _print_auto_tools(auto_tools: WizardToolMap) -> None:
_print_wizard_step("Automatically Discovered Tools")
click.echo("Automatically discovered tools:")
_print_auto_tool_category("Custom API tools", auto_tools["api_tools"])
_print_auto_tool_category("Workflow tools", auto_tools["workflow_tools"])
_print_auto_tool_category("MCP tools", auto_tools["mcp_tools"])
def _print_auto_tool_category(label: str, values: dict[str, str | None]) -> None:
click.echo(label)
if not values:
click.echo("- none")
return
for name, identifier in sorted(values.items()):
click.echo(f"- {_format_tool_name_id(name, identifier)}")
def _prompt_additional_tools(tenant_id: str, auto_tools: WizardToolMap) -> WizardToolSelection:
selections: WizardToolSelection = {"api_tools": [], "workflow_tools": [], "mcp_tools": []}
_print_wizard_step("Additional Tools")
if not click.confirm(
"Export additional tools manually? [y/n, default: n]",
default=False,
show_default=False,
):
_print_final_tool_selection(auto_tools, selections, {})
return selections
manual_labels: dict[str, str] = {}
api_tool_options = [
(tool.name, tool.name, tool.id)
for tool in db.session.scalars(
sa.select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).order_by(ApiToolProvider.name)
).all()
]
selections["api_tools"] = _prompt_tool_category(
"Custom API tools",
api_tool_options,
auto_tools=auto_tools["api_tools"],
)
manual_labels.update(_selected_tool_labels(api_tool_options, selections["api_tools"]))
workflow_tool_options = [
(tool.id, tool.name, tool.id)
for tool in db.session.scalars(
sa.select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id)
.order_by(WorkflowToolProvider.name)
).all()
]
selections["workflow_tools"] = _prompt_tool_category(
"Workflow tools",
workflow_tool_options,
auto_tools=auto_tools["workflow_tools"],
)
manual_labels.update(_selected_tool_labels(workflow_tool_options, selections["workflow_tools"]))
mcp_tool_options = [
(tool.id, tool.name, tool.server_identifier)
for tool in db.session.scalars(
sa.select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name)
).all()
]
selections["mcp_tools"] = _prompt_tool_category(
"MCP tools",
mcp_tool_options,
auto_tools=auto_tools["mcp_tools"],
)
manual_labels.update(_selected_tool_labels(mcp_tool_options, selections["mcp_tools"]))
_print_final_tool_selection(auto_tools, selections, manual_labels)
return selections
def _selected_tool_labels_for_tenant(tenant_id: str, selected_tools: WizardToolSelection) -> dict[str, str]:
labels: dict[str, str] = {}
if selected_tools["api_tools"]:
labels.update(
_selected_tool_labels(
[
(tool.name, tool.name, tool.id)
for tool in db.session.scalars(
sa.select(ApiToolProvider)
.where(ApiToolProvider.tenant_id == tenant_id)
.order_by(ApiToolProvider.name)
).all()
],
selected_tools["api_tools"],
)
)
if selected_tools["workflow_tools"]:
labels.update(
_selected_tool_labels(
[
(tool.id, tool.name, tool.id)
for tool in db.session.scalars(
sa.select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id)
.order_by(WorkflowToolProvider.name)
).all()
],
selected_tools["workflow_tools"],
)
)
if selected_tools["mcp_tools"]:
labels.update(
_selected_tool_labels(
[
(tool.id, tool.name, tool.server_identifier)
for tool in db.session.scalars(
sa.select(MCPToolProvider)
.where(MCPToolProvider.tenant_id == tenant_id)
.order_by(MCPToolProvider.name)
).all()
],
selected_tools["mcp_tools"],
)
)
return labels
def _selected_tool_labels(options: list[tuple[str, str, str]], selected_values: list[str]) -> dict[str, str]:
selected = set(selected_values)
return {value: _format_tool_name_id(name, detail) for value, name, detail in options if value in selected}
def _prompt_tool_category(
label: str,
options: list[tuple[str, str, str]],
*,
auto_tools: dict[str, str | None],
) -> list[str]:
if not options:
click.echo(f"{label}: none")
return []
_print_wizard_step(label)
for index, (value, name, detail) in enumerate(options, 1):
marker = "[auto]" if _is_auto_tool(value, name, detail, auto_tools) else "[ ]"
click.echo(f"{index}. {marker} {name} ({detail})")
raw = click.prompt(
f"Select {label.lower()} by number, comma-separated numbers, all, or empty",
default="",
show_default=cast(Any, "empty"),
)
if not raw.strip():
return []
return parse_index_selection(raw, [value for value, _, _ in options])
def _is_auto_tool(value: str, name: str, detail: str, auto_tools: dict[str, str | None]) -> bool:
return name in auto_tools or value in auto_tools or value in auto_tools.values() or detail in auto_tools.values()
def _print_final_tool_selection(
auto_tools: WizardToolMap,
additional_tools: WizardToolSelection,
manual_labels: dict[str, str],
) -> None:
_print_wizard_step("Final Tool Selection")
_print_tool_selection_body(auto_tools, additional_tools, manual_labels)
def _print_tool_selection_body(
auto_tools: WizardToolMap,
additional_tools: WizardToolSelection,
manual_labels: dict[str, str],
) -> None:
click.echo("Final tools to export:")
_print_final_tool_category(
"Custom API tools",
auto_tools["api_tools"],
additional_tools["api_tools"],
manual_labels,
)
_print_final_tool_category(
"Workflow tools",
auto_tools["workflow_tools"],
additional_tools["workflow_tools"],
manual_labels,
)
_print_final_tool_category("MCP tools", auto_tools["mcp_tools"], additional_tools["mcp_tools"], manual_labels)
def _print_final_tool_category(
label: str,
auto_tools: dict[str, str | None],
manual_values: list[str],
manual_labels: dict[str, str],
) -> None:
click.echo(label)
lines = [f"- [auto] {_format_tool_name_id(name, identifier)}" for name, identifier in sorted(auto_tools.items())]
auto_identifiers = {identifier for identifier in auto_tools.values() if identifier}
lines.extend(
f"- [manual] {manual_labels.get(value, value)}"
for value in manual_values
if value not in auto_tools and value not in auto_identifiers
)
if not lines:
click.echo("- none")
return
for line in lines:
click.echo(line)
def _format_tool_name_id(name: str, identifier: str | None) -> str:
if identifier and identifier != name:
return f"{name}: {identifier}"
return name
def _confirm_wizard_summary(
*,
tenant_name: str,
app_names: list[str],
auto_tools: WizardToolMap,
additional_tools: WizardToolSelection,
manual_labels: dict[str, str],
include_referenced_tools: bool,
include_secrets: bool,
create_tokens: bool,
id_strategy: str,
conflict_strategy: str,
output_file: str,
) -> None:
_print_wizard_step("Summary")
click.echo("Migration export summary:")
click.echo(f"source tenant: {tenant_name}")
click.echo(f"selected apps: {len(app_names)}")
for app_name in app_names:
click.echo(f"- {app_name}")
click.echo(f"auto referenced tools: {str(include_referenced_tools).lower()}")
_print_tool_selection_body(auto_tools, additional_tools, manual_labels)
click.echo(f"include secrets: {str(include_secrets).lower()}")
click.echo(f"create app api token on import: {str(create_tokens).lower()}")
click.echo(f"id strategy: {id_strategy}")
click.echo(f"conflict strategy: {conflict_strategy}")
click.echo(f"output path: {output_file}")
if not click.confirm("Write migration package? [y/n, default: y]", default=True, show_default=False):
raise click.Abort()
def _prompt_output_file() -> tuple[str, bool]:
default_output = f"migration-data-{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
output_file = click.prompt("Output path", default=default_output, show_default=True)
if output_file.lower() in {"y", "yes", "n", "no"}:
raise click.ClickException("Output path must be a file path. Press Enter to use the default path.")
overwrite = False
if Path(output_file).exists():
overwrite = click.confirm(
"Output file exists. Overwrite? [y/n, default: n]",
default=False,
show_default=False,
)
if not overwrite:
raise click.ClickException(f"Output file already exists: {output_file}")
return output_file, overwrite
def _with_output_path(context: ReportContext | None, output_path: str) -> ReportContext:
if context is None:
return ReportContext(output_path=output_path)
return ReportContext(
output_path=output_path,
source_scope=context.source_scope,
selected_app_count=context.selected_app_count,
include_secrets=context.include_secrets,
target_tenant=context.target_tenant,
operator_email=context.operator_email,
app_api_tokens_created=context.app_api_tokens_created,
app_api_tokens_reused=context.app_api_tokens_reused,
id_mapping_count=context.id_mapping_count,
id_mappings=context.id_mappings,
)
def _render_report(report_items: list[ResourceReportItem], *, context: ReportContext | None = None) -> None:
for line in MigrationReportService().render(report_items, context=context):
click.echo(line)

View File

@ -30,7 +30,7 @@ def vdb_migrate(scope: str):
def migrate_annotation_vector_database():
"""
Migrate annotation datas to target vector database .
Migrate annotation data to target vector database.
"""
click.echo(click.style("Starting annotation data migration.", fg="green"))
create_count = 0
@ -140,7 +140,7 @@ def migrate_annotation_vector_database():
def migrate_knowledge_vector_database():
"""
Migrate vector database datas to target vector database .
Migrate vector database data to target vector database.
"""
click.echo(click.style("Starting vector database migration.", fg="green"))
create_count = 0

View File

@ -29,6 +29,7 @@ class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
raise NotImplementedError
@override
def __call__(self) -> dict[str, Any]:
current_state = self.current_state
remote_source_name = current_state.get("REMOTE_SETTINGS_SOURCE_NAME")

View File

@ -81,4 +81,15 @@ default_app_templates: Mapping[AppMode, Mapping] = {
},
},
},
# agent default mode (new Agent App type). The runtime model / prompt / tools
# come from the bound Agent Soul snapshot, so no model_config is seeded in the
# template; create_app still creates a model-less app_model_config row to hold
# app-level presentation features (opener, follow-up, citations, ...).
AppMode.AGENT: {
"app": {
"mode": AppMode.AGENT,
"enable_site": True,
"enable_api": True,
},
},
}

View File

@ -6,10 +6,11 @@ These helpers keep that translation centralized so models registered through
`register_schema_models` emit resolvable Swagger 2.0 references.
"""
from collections.abc import Mapping
from collections.abc import Iterable, Mapping
from enum import StrEnum
from typing import Any, Literal, NotRequired, TypedDict
from typing import Any, Literal, NotRequired, Protocol, TypedDict
from flask import request
from flask_restx import Namespace
from pydantic import BaseModel, TypeAdapter
@ -36,6 +37,12 @@ QueryParamDoc = TypedDict(
)
class QueryArgs(Protocol):
def to_dict(self, flat: bool = True) -> dict[str, str]: ...
def getlist(self, key: str) -> list[str]: ...
def _register_json_schema(namespace: Namespace, name: str, schema: dict) -> None:
"""Register a JSON schema and promote any nested Pydantic `$defs`."""
@ -167,6 +174,58 @@ def query_params_from_model(model: type[BaseModel]) -> dict[str, QueryParamDoc]:
return params
def query_params_from_request[ModelT: BaseModel](
model: type[ModelT],
*,
list_fields: Iterable[str] = (),
args: QueryArgs | None = None,
use_defaults_for_malformed_ints: bool = False,
) -> ModelT:
"""Validate query args with Pydantic while preserving Flask query parsing behavior.
Repeated params need explicit ``getlist()`` handling because Werkzeug's
``to_dict()`` keeps only one value. For malformed scalar integers, Flask's
For endpoints migrated from ``request.args.get(..., type=int, default=...)``,
set ``use_defaults_for_malformed_ints`` to preserve Flask's fallback to
defaults for malformed optional integer params.
"""
query_args = args or request.args
params: dict[str, Any] = query_args.to_dict()
for field_name in list_fields:
params[field_name] = query_args.getlist(field_name)
if use_defaults_for_malformed_ints:
_drop_malformed_defaulted_integer_params(model, params)
return model.model_validate(params)
def _drop_malformed_defaulted_integer_params(model: type[BaseModel], params: dict[str, Any]) -> None:
properties = model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0).get("properties", {})
if not isinstance(properties, Mapping):
return
for name, value in list(params.items()):
if not isinstance(value, str):
continue
field = model.model_fields.get(name)
if field is None or field.is_required():
continue
property_schema = properties.get(name)
if not isinstance(property_schema, Mapping):
continue
if _nullable_property_schema(property_schema).get("type") != "integer":
continue
try:
int(value)
except ValueError:
params.pop(name)
def _query_param_from_property(property_schema: Mapping[str, Any], *, required: bool) -> QueryParamDoc:
param_schema = _nullable_property_schema(property_schema)
param_doc: QueryParamDoc = {"in": "query", "required": required}
@ -239,6 +298,7 @@ __all__ = [
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
"get_or_create_model",
"query_params_from_model",
"query_params_from_request",
"register_enum_models",
"register_response_schema_model",
"register_response_schema_models",

View File

@ -51,6 +51,8 @@ from .agent import roster as agent_roster
from .app import (
advanced_prompt_template,
agent,
agent_app_access,
agent_app_feature,
annotation,
app,
audio,
@ -146,6 +148,8 @@ __all__ = [
"activate",
"advanced_prompt_template",
"agent",
"agent_app_access",
"agent_app_feature",
"agent_composer",
"agent_providers",
"agent_roster",

View File

@ -1,53 +1,91 @@
from flask_restx import Resource
from controllers.common.schema import register_schema_models
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from libs.login import current_account_with_tenant, login_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user_id,
)
from fields.agent_fields import (
AgentAppComposerResponse,
AgentComposerCandidatesResponse,
AgentComposerImpactResponse,
AgentComposerValidateResponse,
WorkflowAgentComposerResponse,
)
from libs.helper import dump_response
from libs.login import login_required
from models.model import App, AppMode
from services.agent.composer_service import AgentComposerService
from services.agent.composer_validator import ComposerConfigValidator
from services.entities.agent_entities import ComposerSavePayload
register_schema_models(console_ns, ComposerSavePayload)
register_response_schema_models(
console_ns,
AgentAppComposerResponse,
AgentComposerCandidatesResponse,
AgentComposerImpactResponse,
AgentComposerValidateResponse,
WorkflowAgentComposerResponse,
)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer")
class WorkflowAgentComposerApi(Resource):
@console_ns.response(
200, "Workflow agent composer state", console_ns.models[WorkflowAgentComposerResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
def get(self, app_model: App, node_id: str):
_, tenant_id = current_account_with_tenant()
return AgentComposerService.load_workflow_composer(
tenant_id=tenant_id,
app_id=app_model.id,
node_id=node_id,
@with_current_tenant_id
def get(self, tenant_id: str, app_model: App, node_id: str):
return dump_response(
WorkflowAgentComposerResponse,
AgentComposerService.load_workflow_composer(
tenant_id=tenant_id,
app_id=app_model.id,
node_id=node_id,
),
)
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
@console_ns.response(
200, "Workflow agent composer saved", console_ns.models[WorkflowAgentComposerResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
def put(self, app_model: App, node_id: str):
account, tenant_id = current_account_with_tenant()
@with_current_user_id
@with_current_tenant_id
def put(self, tenant_id: str, account_id: str, app_model: App, node_id: str):
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
return AgentComposerService.save_workflow_composer(
tenant_id=tenant_id,
app_id=app_model.id,
node_id=node_id,
account_id=account.id,
payload=payload,
return dump_response(
WorkflowAgentComposerResponse,
AgentComposerService.save_workflow_composer(
tenant_id=tenant_id,
app_id=app_model.id,
node_id=node_id,
account_id=account_id,
payload=payload,
),
)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/validate")
class WorkflowAgentComposerValidateApi(Resource):
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
@console_ns.response(
200, "Workflow agent composer validation result", console_ns.models[AgentComposerValidateResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@ -55,84 +93,116 @@ class WorkflowAgentComposerValidateApi(Resource):
def post(self, app_model: App, node_id: str):
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
ComposerConfigValidator.validate_save_payload(payload)
return {"result": "success", "errors": []}
return dump_response(AgentComposerValidateResponse, {"result": "success", "errors": []})
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/candidates")
class WorkflowAgentComposerCandidatesApi(Resource):
@console_ns.response(
200, "Workflow agent composer candidates", console_ns.models[AgentComposerCandidatesResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
def get(self, app_model: App, node_id: str):
return AgentComposerService.get_workflow_candidates(app_id=app_model.id)
return dump_response(
AgentComposerCandidatesResponse,
AgentComposerService.get_workflow_candidates(app_id=app_model.id),
)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/impact")
class WorkflowAgentComposerImpactApi(Resource):
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
@console_ns.response(200, "Workflow agent composer impact", console_ns.models[AgentComposerImpactResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
def post(self, app_model: App, node_id: str):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, tenant_id: str, app_model: App, node_id: str):
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
current_snapshot_id = payload.binding.current_snapshot_id if payload.binding else None
if not current_snapshot_id:
return {"current_snapshot_id": None, "workflow_node_count": 0, "bindings": []}
return AgentComposerService.calculate_impact(tenant_id=tenant_id, current_snapshot_id=current_snapshot_id)
return dump_response(
AgentComposerImpactResponse, {"current_snapshot_id": None, "workflow_node_count": 0, "bindings": []}
)
return dump_response(
AgentComposerImpactResponse,
AgentComposerService.calculate_impact(tenant_id=tenant_id, current_snapshot_id=current_snapshot_id),
)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/save-to-roster")
class WorkflowAgentComposerSaveToRosterApi(Resource):
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
@console_ns.response(
200, "Workflow agent composer saved to roster", console_ns.models[WorkflowAgentComposerResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
def post(self, app_model: App, node_id: str):
account, tenant_id = current_account_with_tenant()
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, account_id: str, app_model: App, node_id: str):
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
return AgentComposerService.save_workflow_composer(
tenant_id=tenant_id,
app_id=app_model.id,
node_id=node_id,
account_id=account.id,
payload=payload,
return dump_response(
WorkflowAgentComposerResponse,
AgentComposerService.save_workflow_composer(
tenant_id=tenant_id,
app_id=app_model.id,
node_id=node_id,
account_id=account_id,
payload=payload,
),
)
@console_ns.route("/apps/<uuid:app_id>/agent-composer")
class AgentAppComposerApi(Resource):
@console_ns.response(200, "Agent app composer state", console_ns.models[AgentAppComposerResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model()
def get(self, app_model: App):
_, tenant_id = current_account_with_tenant()
return AgentComposerService.load_agent_app_composer(tenant_id=tenant_id, app_id=app_model.id)
@with_current_tenant_id
def get(self, tenant_id: str, app_model: App):
return dump_response(
AgentAppComposerResponse,
AgentComposerService.load_agent_app_composer(tenant_id=tenant_id, app_id=app_model.id),
)
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
@console_ns.response(200, "Agent app composer saved", console_ns.models[AgentAppComposerResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_app_model()
def put(self, app_model: App):
account, tenant_id = current_account_with_tenant()
@with_current_user_id
@with_current_tenant_id
def put(self, tenant_id: str, account_id: str, app_model: App):
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
return AgentComposerService.save_agent_app_composer(
tenant_id=tenant_id,
app_id=app_model.id,
account_id=account.id,
payload=payload,
return dump_response(
AgentAppComposerResponse,
AgentComposerService.save_agent_app_composer(
tenant_id=tenant_id,
app_id=app_model.id,
account_id=account_id,
payload=payload,
),
)
@console_ns.route("/apps/<uuid:app_id>/agent-composer/validate")
class AgentAppComposerValidateApi(Resource):
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
@console_ns.response(
200, "Agent app composer validation result", console_ns.models[AgentComposerValidateResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@ -140,14 +210,20 @@ class AgentAppComposerValidateApi(Resource):
def post(self, app_model: App):
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
ComposerConfigValidator.validate_save_payload(payload)
return {"result": "success", "errors": []}
return dump_response(AgentComposerValidateResponse, {"result": "success", "errors": []})
@console_ns.route("/apps/<uuid:app_id>/agent-composer/candidates")
class AgentAppComposerCandidatesApi(Resource):
@console_ns.response(
200, "Agent app composer candidates", console_ns.models[AgentComposerCandidatesResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@get_app_model()
def get(self, app_model: App):
return AgentComposerService.get_agent_app_candidates(app_id=app_model.id)
return dump_response(
AgentComposerCandidatesResponse,
AgentComposerService.get_agent_app_candidates(app_id=app_model.id),
)

View File

@ -4,11 +4,25 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_models
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user_id,
)
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from fields.agent_fields import (
AgentConfigSnapshotDetailResponse,
AgentConfigSnapshotListResponse,
AgentInviteOptionsResponse,
AgentRosterListResponse,
AgentRosterResponse,
)
from libs.helper import dump_response
from libs.login import login_required
from services.agent.roster_service import AgentRosterService
from services.entities.agent_entities import RosterAgentCreatePayload, RosterAgentUpdatePayload, RosterListQuery
@ -29,6 +43,14 @@ register_schema_models(
RosterAgentUpdatePayload,
RosterListQuery,
)
register_response_schema_models(
console_ns,
AgentConfigSnapshotDetailResponse,
AgentConfigSnapshotListResponse,
AgentInviteOptionsResponse,
AgentRosterListResponse,
AgentRosterResponse,
)
def _agent_roster_service() -> AgentRosterService:
@ -37,96 +59,130 @@ def _agent_roster_service() -> AgentRosterService:
@console_ns.route("/agents")
class AgentRosterListApi(Resource):
@console_ns.doc(params=query_params_from_model(RosterListQuery))
@console_ns.response(200, "Agent roster list", console_ns.models[AgentRosterListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str):
query = RosterListQuery.model_validate(request.args.to_dict(flat=True))
return _agent_roster_service().list_roster_agents(
tenant_id=tenant_id, page=query.page, limit=query.limit, keyword=query.keyword
return dump_response(
AgentRosterListResponse,
_agent_roster_service().list_roster_agents(
tenant_id=tenant_id, page=query.page, limit=query.limit, keyword=query.keyword
),
)
@console_ns.expect(console_ns.models[RosterAgentCreatePayload.__name__])
@console_ns.response(201, "Agent created", console_ns.models[AgentRosterResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
account, tenant_id = current_account_with_tenant()
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, account_id: str):
payload = RosterAgentCreatePayload.model_validate(console_ns.payload or {})
service = _agent_roster_service()
agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account.id, payload=payload)
return service.get_roster_agent_detail(tenant_id=tenant_id, agent_id=agent.id), 201
agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account_id, payload=payload)
return dump_response(
AgentRosterResponse,
service.get_roster_agent_detail(tenant_id=tenant_id, agent_id=agent.id),
), 201
@console_ns.route("/agents/invite-options")
class AgentInviteOptionsApi(Resource):
@console_ns.doc(params=query_params_from_model(AgentInviteOptionsQuery))
@console_ns.response(200, "Agent invite options", console_ns.models[AgentInviteOptionsResponse.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str):
query = AgentInviteOptionsQuery.model_validate(request.args.to_dict(flat=True))
return _agent_roster_service().list_invite_options(
tenant_id=tenant_id,
page=query.page,
limit=query.limit,
keyword=query.keyword,
app_id=query.app_id,
return dump_response(
AgentInviteOptionsResponse,
_agent_roster_service().list_invite_options(
tenant_id=tenant_id,
page=query.page,
limit=query.limit,
keyword=query.keyword,
app_id=query.app_id,
),
)
@console_ns.route("/agents/<uuid:agent_id>")
class AgentRosterDetailApi(Resource):
@console_ns.response(200, "Agent detail", console_ns.models[AgentRosterResponse.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self, agent_id: UUID):
_, tenant_id = current_account_with_tenant()
return _agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id))
@console_ns.expect(console_ns.models[RosterAgentUpdatePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def patch(self, agent_id: UUID):
account, tenant_id = current_account_with_tenant()
payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {})
return _agent_roster_service().update_roster_agent(
tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id, payload=payload
@with_current_tenant_id
def get(self, tenant_id: str, agent_id: UUID):
return dump_response(
AgentRosterResponse,
_agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id)),
)
@console_ns.expect(console_ns.models[RosterAgentUpdatePayload.__name__])
@console_ns.response(200, "Agent updated", console_ns.models[AgentRosterResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, agent_id: UUID):
account, tenant_id = current_account_with_tenant()
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id)
@with_current_user_id
@with_current_tenant_id
def patch(self, tenant_id: str, account_id: str, agent_id: UUID):
payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {})
return dump_response(
AgentRosterResponse,
_agent_roster_service().update_roster_agent(
tenant_id=tenant_id, agent_id=str(agent_id), account_id=account_id, payload=payload
),
)
@console_ns.response(204, "Agent archived")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@with_current_user_id
@with_current_tenant_id
def delete(self, tenant_id: str, account_id: str, agent_id: UUID):
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account_id)
return "", 204
@console_ns.route("/agents/<uuid:agent_id>/versions")
class AgentRosterVersionsApi(Resource):
@console_ns.response(200, "Agent versions", console_ns.models[AgentConfigSnapshotListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self, agent_id: UUID):
_, tenant_id = current_account_with_tenant()
return {"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))}
@with_current_tenant_id
def get(self, tenant_id: str, agent_id: UUID):
return dump_response(
AgentConfigSnapshotListResponse,
{"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))},
)
@console_ns.route("/agents/<uuid:agent_id>/versions/<uuid:version_id>")
class AgentRosterVersionDetailApi(Resource):
@console_ns.response(200, "Agent version detail", console_ns.models[AgentConfigSnapshotDetailResponse.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self, agent_id: UUID, version_id: UUID):
_, tenant_id = current_account_with_tenant()
return _agent_roster_service().get_agent_version_detail(
tenant_id=tenant_id,
agent_id=str(agent_id),
version_id=str(version_id),
@with_current_tenant_id
def get(self, tenant_id: str, agent_id: UUID, version_id: UUID):
return dump_response(
AgentConfigSnapshotDetailResponse,
_agent_roster_service().get_agent_version_detail(
tenant_id=tenant_id,
agent_id=str(agent_id),
version_id=str(version_id),
),
)

View File

@ -0,0 +1,59 @@
"""Agent App access & sharing endpoints (read-only workflow references).
An Agent App is backed by a roster Agent that workflow Agent nodes may also
reference. This exposes the read-only "Workflow access" surface from the PRD:
which workflow apps use this Agent, without leaking the workflows' internals.
"""
from flask_restx import Resource
from pydantic import Field
from controllers.common.schema import register_response_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.login import login_required
from models.model import App, AppMode
from services.agent.roster_service import AgentRosterService
class AgentReferencingWorkflowResponse(ResponseModel):
app_id: str
app_name: str
app_mode: str
workflow_id: str
node_ids: list[str] = Field(default_factory=list)
class AgentReferencingWorkflowsResponse(ResponseModel):
data: list[AgentReferencingWorkflowResponse] = Field(default_factory=list)
register_response_schema_models(console_ns, AgentReferencingWorkflowsResponse)
@console_ns.route("/apps/<uuid:app_id>/agent-referencing-workflows")
class AgentAppReferencingWorkflowsResource(Resource):
@console_ns.doc("list_agent_app_referencing_workflows")
@console_ns.doc(description="List workflow apps that reference this Agent App's bound Agent (read-only)")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(
200,
"Referencing workflows listed successfully",
console_ns.models[AgentReferencingWorkflowsResponse.__name__],
)
@console_ns.response(404, "App not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT])
@with_current_tenant_id
def get(self, tenant_id: str, app_model: App):
workflows = AgentRosterService(db.session).list_workflows_referencing_app_agent(
tenant_id=tenant_id, app_id=app_model.id
)
return AgentReferencingWorkflowsResponse(
data=[AgentReferencingWorkflowResponse.model_validate(workflow) for workflow in workflows]
).model_dump(mode="json")

View File

@ -0,0 +1,93 @@
"""Agent App presentation-feature configuration endpoint.
The new Agent App type keeps model / prompt / tools in its bound Agent Soul, so
the legacy ``/model-config`` surface (which writes model, prompt and agent tool
config) is the wrong place to configure its app-level presentation features.
This endpoint exposes only the PRD "Misc Legacy" feature subset — conversation
opener, follow-up suggestions, citations, content moderation and speech — and
persists them onto the app's ``app_model_config`` without touching anything the
Soul owns.
"""
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user,
)
from events.app_event import app_model_config_was_updated
from libs.helper import dump_response
from libs.login import login_required
from models import Account
from models.agent_config_entities import (
AgentFeatureToggleConfig,
AgentSensitiveWordAvoidanceFeatureConfig,
AgentSuggestedQuestionsAfterAnswerFeatureConfig,
AgentTextToSpeechFeatureConfig,
)
from models.model import App, AppMode
from services.agent_app_feature_service import AgentAppFeatureConfigService
class AgentAppFeaturesPayload(BaseModel):
"""Presentation features configurable on an Agent App.
All fields are optional; an omitted field is reset to its disabled/empty
default (the config form sends the full desired feature state on save).
"""
opening_statement: str | None = Field(default=None, description="Conversation opener shown before the first turn")
suggested_questions: list[str] | None = Field(
default=None, description="Preset questions shown alongside the opener"
)
suggested_questions_after_answer: AgentSuggestedQuestionsAfterAnswerFeatureConfig | None = Field(
default=None, description="Follow-up suggestions config, e.g. {'enabled': true}"
)
speech_to_text: AgentFeatureToggleConfig | None = Field(default=None, description="Speech-to-text config")
text_to_speech: AgentTextToSpeechFeatureConfig | None = Field(default=None, description="Text-to-speech config")
retriever_resource: AgentFeatureToggleConfig | None = Field(
default=None, description="Citations / attributions config, e.g. {'enabled': true}"
)
sensitive_word_avoidance: AgentSensitiveWordAvoidanceFeatureConfig | None = Field(
default=None, description="Content moderation config"
)
register_schema_models(console_ns, AgentAppFeaturesPayload)
register_response_schema_models(console_ns, SimpleResultResponse)
@console_ns.route("/apps/<uuid:app_id>/agent-features")
class AgentAppFeatureConfigResource(Resource):
@console_ns.doc("update_agent_app_features")
@console_ns.doc(description="Update an Agent App's presentation features (opener, follow-up, citations, ...)")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AgentAppFeaturesPayload.__name__])
@console_ns.response(200, "Features updated successfully", console_ns.models[SimpleResultResponse.__name__])
@console_ns.response(400, "Invalid configuration")
@console_ns.response(404, "App not found")
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT])
@with_current_user
def post(self, current_user: Account, app_model: App):
args = AgentAppFeaturesPayload.model_validate(console_ns.payload or {})
new_app_model_config = AgentAppFeatureConfigService.update_features(
app_model=app_model,
account=current_user,
config=args.model_dump(exclude_none=True),
)
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
return dump_response(SimpleResultResponse, {"result": "success"})

View File

@ -25,6 +25,9 @@ from controllers.console.wraps import (
enterprise_license_required,
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
with_current_user,
with_current_user_id,
)
from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.entities import PreProcessingRule, Rule, Segmentation
@ -34,8 +37,8 @@ from extensions.ext_database import db
from fields.base import ResponseModel
from graphon.enums import WorkflowExecutionStatus
from libs.helper import build_icon_url, to_timestamp
from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow
from libs.login import login_required
from models import Account, App, DatasetPermissionEnum, Workflow
from models.model import IconType
from services.app_dsl_service import AppDslService
from services.app_service import AppListParams, AppService, CreateAppParams
@ -55,7 +58,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
)
from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "agent", "advanced-chat", "workflow", "completion"]
register_enum_models(console_ns, IconType)
@ -66,7 +69,7 @@ _TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
class AppListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"] = Field(
default="all", description="App mode filter"
)
name: str | None = Field(default=None, description="Filter by app name")
@ -115,7 +118,9 @@ def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str,
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
mode: Literal["chat", "agent-chat", "agent", "advanced-chat", "workflow", "completion"] = Field(
..., description="App mode"
)
icon_type: IconType | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@ -393,6 +398,8 @@ class AppDetailWithSite(AppDetail):
max_active_requests: int | None = None
deleted_tools: list[DeletedTool] = Field(default_factory=list)
site: Site | None = None
# For Agent App type: the roster Agent backing this app (None otherwise).
bound_agent_id: str | None = None
@computed_field(return_type=str | None) # type: ignore
@property
@ -467,10 +474,11 @@ class AppListApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
def get(self):
@with_session(write=False)
@with_current_user_id
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user_id: str, session: Session):
"""Get app list"""
current_user, current_tenant_id = current_account_with_tenant()
args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args))
params = AppListParams(
page=args.page,
@ -483,7 +491,7 @@ class AppListApi(Resource):
# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, params)
app_pagination = app_service.get_paginate_apps(current_user_id, current_tenant_id, params)
if not app_pagination:
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
return empty.model_dump(mode="json"), 200
@ -504,7 +512,7 @@ class AppListApi(Resource):
draft_trigger_app_ids: set[str] = set()
if workflow_capable_app_ids:
draft_workflows = (
db.session.execute(
session.execute(
select(Workflow).where(
Workflow.version == Workflow.VERSION_DRAFT,
Workflow.app_id.in_(workflow_capable_app_ids),
@ -543,9 +551,10 @@ class AppListApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account):
"""Create app"""
current_user, current_tenant_id = current_account_with_tenant()
args = CreateAppPayload.model_validate(console_ns.payload)
params = CreateAppParams(
name=args.name,
@ -648,11 +657,10 @@ class AppCopyApi(Resource):
@account_initialization_required
@get_app_model(mode=None)
@edit_permission_required
def post(self, app_model: App):
@with_current_user
def post(self, current_user: Account, app_model: App):
"""Copy app"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
args = CopyAppPayload.model_validate(console_ns.payload or {})
with Session(db.engine, expire_on_commit=False) as session:
@ -731,7 +739,8 @@ class AppPublishToCreatorsPlatformApi(Resource):
@account_initialization_required
@get_app_model(mode=None)
@edit_permission_required
def post(self, app_model: App):
@with_current_user_id
def post(self, current_user_id: str, app_model: App):
"""Publish app to Creators Platform"""
from configs import dify_config
from core.helper.creators import get_redirect_url, upload_dsl
@ -739,13 +748,11 @@ class AppPublishToCreatorsPlatformApi(Resource):
if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
return {"error": "Creators Platform features are not enabled"}, 403
current_user, _ = current_account_with_tenant()
dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
dsl_bytes = dsl_content.encode("utf-8")
claim_code = upload_dsl(dsl_bytes)
redirect_url = get_redirect_url(str(current_user.id), claim_code)
redirect_url = get_redirect_url(current_user_id, claim_code)
return {"redirect_url": redirect_url}

View File

@ -9,9 +9,11 @@ from controllers.console.wraps import (
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
with_current_user,
)
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.account import Account
from models.model import App
from services.app_dsl_service import AppDslService, Import
from services.enterprise.enterprise_service import EnterpriseService
@ -48,9 +50,9 @@ class AppImportApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
@with_current_user
def post(self, current_user: Account):
# Check user role first
current_user, _ = current_account_with_tenant()
args = AppImportPayload.model_validate(console_ns.payload)
# AppDslService performs internal commits for some creation paths, so use a plain
@ -97,10 +99,9 @@ class AppImportConfirmApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, import_id: str):
@with_current_user
def post(self, current_user: Account, import_id: str):
# Check user role first
current_user, _ = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
import_service = AppDslService(session)
# Confirm import

View File

@ -4,7 +4,7 @@ from typing import Any, Literal
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
from controllers.common.fields import SimpleResultResponse
@ -19,7 +19,12 @@ from controllers.console.app.error import (
ProviderQuotaExceededError,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user_id,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
@ -41,9 +46,24 @@ from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
def _resolve_debugger_chat_streaming(
*, app_mode: AppMode, response_mode: str, response_mode_provided: bool = True
) -> bool:
"""Agent App runtime is SSE-only until backend blocking runs are supported."""
if app_mode != AppMode.AGENT:
return response_mode != "blocking"
if response_mode_provided and response_mode == "blocking":
raise BadRequest("Agent App only supports streaming response mode.")
return True
class BaseMessagePayload(BaseModel):
inputs: dict[str, Any]
model_config_data: dict[str, Any] = Field(..., alias="model_config")
# Agent Apps (AppMode.AGENT) derive their model + prompt from the bound Agent
# Soul, so no override ``model_config`` is sent; chat / agent-chat / completion
# debugging still pass it. Optional here, required in practice by those modes
# downstream when their config is built from args.
model_config_data: dict[str, Any] = Field(default_factory=dict, alias="model_config")
files: list[Any] | None = Field(default=None, description="Uploaded files")
response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
retriever_from: str = Field(default="dev", description="Retriever source")
@ -131,14 +151,13 @@ class CompletionMessageStopApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model: App, task_id: str):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
@with_current_user_id
def post(self, current_user_id: str, app_model: App, task_id: str):
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.DEBUGGER,
user_id=current_user.id,
user_id=current_user_id,
app_mode=AppMode.value_of(app_model.mode),
)
@ -157,13 +176,20 @@ class ChatMessageApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.AGENT])
@edit_permission_required
def post(self, app_model: App):
args_model = ChatMessagePayload.model_validate(console_ns.payload)
raw_payload = console_ns.payload or {}
args_model = ChatMessagePayload.model_validate(raw_payload)
args = args_model.model_dump(exclude_none=True, by_alias=True)
streaming = args_model.response_mode != "blocking"
streaming = _resolve_debugger_chat_streaming(
app_mode=AppMode.value_of(app_model.mode),
response_mode=args_model.response_mode,
response_mode_provided=isinstance(raw_payload, dict) and "response_mode" in raw_payload,
)
if AppMode.value_of(app_model.mode) == AppMode.AGENT:
args["response_mode"] = "streaming"
args["auto_generate_name"] = False
external_trace_id = get_external_trace_id(request)
@ -211,15 +237,14 @@ class ChatMessageStopApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def post(self, app_model: App, task_id: str):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@with_current_user_id
def post(self, current_user_id: str, app_model: App, task_id: str):
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.DEBUGGER,
user_id=current_user.id,
user_id=current_user_id,
app_mode=AppMode.value_of(app_model.mode),
)

View File

@ -12,7 +12,12 @@ from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user,
)
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
@ -31,8 +36,9 @@ from fields.conversation_fields import (
ConversationWithSummaryPagination as ConversationWithSummaryPaginationResponse,
)
from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Conversation, EndUser, Message, MessageAnnotation
from models.account import Account
from models.model import App, AppMode
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError
@ -93,8 +99,8 @@ class CompletionConversationApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
def get(self, app_model: App):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account, app_model: App):
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True))
query = sa.select(Conversation).where(
@ -165,10 +171,11 @@ class CompletionConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
def get(self, app_model: App, conversation_id: UUID):
@with_current_user
def get(self, current_user: Account, app_model: App, conversation_id: UUID):
conversation_id_str = str(conversation_id)
return ConversationMessageDetailResponse.model_validate(
_get_conversation(app_model, conversation_id_str), from_attributes=True
_get_conversation(current_user, app_model, conversation_id_str), from_attributes=True
).model_dump(mode="json")
@console_ns.doc("delete_completion_conversation")
@ -182,8 +189,8 @@ class CompletionConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
def delete(self, app_model: App, conversation_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def delete(self, current_user: Account, app_model: App, conversation_id: UUID):
conversation_id_str = str(conversation_id)
try:
@ -205,10 +212,10 @@ class ChatConversationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@edit_permission_required
def get(self, app_model: App):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account, app_model: App):
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True))
subquery = (
@ -316,12 +323,13 @@ class ChatConversationDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@edit_permission_required
def get(self, app_model: App, conversation_id: UUID):
@with_current_user
def get(self, current_user: Account, app_model: App, conversation_id: UUID):
conversation_id_str = str(conversation_id)
return ConversationDetailResponse.model_validate(
_get_conversation(app_model, conversation_id_str), from_attributes=True
_get_conversation(current_user, app_model, conversation_id_str), from_attributes=True
).model_dump(mode="json")
@console_ns.doc("delete_chat_conversation")
@ -332,11 +340,11 @@ class ChatConversationDetailApi(Resource):
@console_ns.response(404, "Conversation not found")
@setup_required
@login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@account_initialization_required
@edit_permission_required
def delete(self, app_model: App, conversation_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def delete(self, current_user: Account, app_model: App, conversation_id: UUID):
conversation_id_str = str(conversation_id)
try:
@ -347,8 +355,7 @@ class ChatConversationDetailApi(Resource):
return "", 204
def _get_conversation(app_model, conversation_id):
current_user, _ = current_account_with_tenant()
def _get_conversation(current_user: Account, app_model, conversation_id):
conversation = db.session.scalar(
sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1)
)

View File

@ -2,6 +2,7 @@ from collections.abc import Sequence
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
@ -11,6 +12,7 @@ from controllers.console.app.error import (
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.app.wraps import with_session
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
from core.app.app_config.entities import ModelConfig
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
@ -19,7 +21,6 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.llm_generator import LLMGenerator
from extensions.ext_database import db
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
@ -158,7 +159,8 @@ class InstructionGenerateApi(Resource):
@login_required
@account_initialization_required
@with_current_tenant_id
def post(self, current_tenant_id: str):
@with_session(write=False)
def post(self, session: Session, current_tenant_id: str):
args = InstructionGeneratePayload.model_validate(console_ns.payload)
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] | None = next(
@ -168,10 +170,10 @@ class InstructionGenerateApi(Resource):
try:
# Generate from nothing for a workflow node
if (args.current in (code_template, "")) and args.node_id != "":
app = db.session.get(App, args.flow_id)
app = session.get(App, args.flow_id)
if not app:
return {"error": f"app {args.flow_id} not found"}, 400
workflow = WorkflowService().get_draft_workflow(app_model=app)
workflow = WorkflowService().get_draft_workflow(app_model=app, session=session)
if not workflow:
return {"error": f"workflow {args.flow_id} not found"}, 400
nodes: Sequence = workflow.graph_dict["nodes"]

View File

@ -25,6 +25,7 @@ from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user,
)
from core.app.entities.app_invoke_entities import InvokeFrom
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
@ -43,7 +44,8 @@ from fields.conversation_fields import (
from graphon.model_runtime.errors.invoke import InvokeError
from libs.helper import to_timestamp, uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.account import Account
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import App, AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.errors.conversation import ConversationNotExistsError
@ -178,7 +180,7 @@ class ChatMessageListApi(Resource):
@login_required
@account_initialization_required
@setup_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@edit_permission_required
def get(self, app_model: App):
args = ChatMessagesQuery.model_validate(request.args.to_dict())
@ -257,9 +259,8 @@ class MessageFeedbackApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, app_model: App):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, app_model: App):
args = MessageFeedbackPayload.model_validate(console_ns.payload)
message_id = str(args.message_id)
@ -336,9 +337,9 @@ class MessageSuggestedQuestionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model: App, message_id: UUID):
current_user, _ = current_account_with_tenant()
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@with_current_user
def get(self, current_user: Account, app_model: App, message_id: UUID):
message_id_str = str(message_id)
try:

View File

@ -8,14 +8,20 @@ from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user_id,
)
from core.agent.entities import AgentToolEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.model import App, AppMode, AppModelConfig
from services.app_model_config_service import AppModelConfigService
@ -52,9 +58,10 @@ class ModelConfigResource(Resource):
@edit_permission_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
def post(self, app_model: App):
@with_current_user_id
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user_id: str, app_model: App):
"""Modify app model config"""
current_user, current_tenant_id = current_account_with_tenant()
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_tenant_id,
@ -64,8 +71,8 @@ class ModelConfigResource(Resource):
new_app_model_config = AppModelConfig(
app_id=app_model.id,
created_by=current_user.id,
updated_by=current_user.id,
created_by=current_user_id,
updated_by=current_user_id,
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
@ -90,7 +97,7 @@ class ModelConfigResource(Resource):
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
user_id=current_user.id,
user_id=current_user_id,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_tenant_id,
@ -130,7 +137,7 @@ class ModelConfigResource(Resource):
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
user_id=current_user.id,
user_id=current_user_id,
)
except Exception:
continue
@ -167,7 +174,7 @@ class ModelConfigResource(Resource):
db.session.flush()
app_model.app_model_config_id = new_app_model_config.id
app_model.updated_by = current_user.id
app_model.updated_by = current_user_id
app_model.updated_at = naive_utc_now()
db.session.commit()

View File

@ -14,12 +14,14 @@ from controllers.console.wraps import (
edit_permission_required,
is_admin_or_owner_required,
setup_required,
with_current_user,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Site
from models.account import Account
from models.model import App
@ -85,9 +87,9 @@ class AppSite(Resource):
@edit_permission_required
@account_initialization_required
@get_app_model
def post(self, app_model: App):
@with_current_user
def post(self, current_user: Account, app_model: App):
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
current_user, _ = current_account_with_tenant()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
if not site:
raise NotFound
@ -134,8 +136,8 @@ class AppSiteAccessTokenReset(Resource):
@is_admin_or_owner_required
@account_initialization_required
@get_app_model
def post(self, app_model: App):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, app_model: App):
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
if not site:

View File

@ -8,13 +8,14 @@ from pydantic import BaseModel, Field, field_validator
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import convert_datetime_to_date
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import AppMode
from models.account import Account
from models.model import App
@ -48,9 +49,8 @@ class DailyMessageStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")
@ -109,9 +109,8 @@ class DailyConversationStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")
@ -169,9 +168,8 @@ class DailyTerminalsStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")
@ -230,9 +228,8 @@ class DailyTokenCostStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")
@ -293,10 +290,9 @@ class AverageSessionInteractionStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("c.created_at")
@ -374,9 +370,8 @@ class UserSatisfactionRateStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("m.created_at")
@ -444,9 +439,8 @@ class AverageResponseTimeStatistic(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")
@ -505,8 +499,8 @@ class TokensPerSecondStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")

View File

@ -6,10 +6,11 @@ from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.account import Account
from models.enums import WorkflowRunTriggeredFrom
from models.model import App, AppMode
from repositories.factory import DifyAPIRepositoryFactory
@ -46,9 +47,8 @@ class WorkflowDailyRunsStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
assert account.timezone is not None
@ -86,9 +86,8 @@ class WorkflowDailyTerminalsStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
assert account.timezone is not None
@ -126,9 +125,8 @@ class WorkflowDailyTokenCostStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
assert account.timezone is not None
@ -166,9 +164,8 @@ class WorkflowAverageAppInteractionStatistic(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
assert account.timezone is not None

View File

@ -32,11 +32,11 @@ from controllers.console.wraps import (
decrypt_password_field,
email_password_login_enabled,
setup_required,
with_current_user,
)
from events.tenant_event import tenant_was_created
from libs.helper import EmailStr, extract_remote_ip
from libs.helper import timezone as validate_timezone_string
from libs.login import current_account_with_tenant
from libs.token import (
clear_access_token_from_cookie,
clear_csrf_token_from_cookie,
@ -46,6 +46,7 @@ from libs.token import (
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
)
from models.account import Account
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
from services.billing_service import BillingService
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
@ -172,9 +173,8 @@ class LoginApi(Resource):
class LogoutApi(Resource):
@setup_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
account = current_user
@with_current_user
def post(self, account: Account):
if isinstance(account, flask_login.AnonymousUserMixin):
response = make_response({"result": "success"})
else:

View File

@ -8,9 +8,16 @@ from werkzeug.exceptions import BadRequest
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from controllers.console.wraps import (
account_initialization_required,
only_edition_cloud,
setup_required,
with_current_tenant_id,
with_current_user,
)
from enums.cloud_plan import CloudPlan
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from services.billing_service import BillingService
@ -32,8 +39,9 @@ class Subscription(Resource):
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True))
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
@ -45,8 +53,9 @@ class Invoices(Resource):
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_invoices(current_user.email, current_tenant_id)
@ -63,9 +72,8 @@ class PartnerTenants(Resource):
@login_required
@account_initialization_required
@only_edition_cloud
def put(self, partner_key: str):
current_user, _ = current_account_with_tenant()
@with_current_user
def put(self, current_user: Account, partner_key: str):
try:
args = PartnerTenantsPayload.model_validate(console_ns.payload or {})
click_id = args.click_id

View File

@ -3,11 +3,18 @@ from flask_restx import Resource
from pydantic import BaseModel, Field
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from services.billing_service import BillingService
from .. import console_ns
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
from ..wraps import (
account_initialization_required,
only_edition_cloud,
setup_required,
with_current_tenant_id,
with_current_user,
)
class ComplianceDownloadQuery(BaseModel):
@ -29,8 +36,9 @@ class ComplianceApi(Resource):
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True))
ip_address = extract_remote_ip(request)

View File

@ -9,7 +9,7 @@ from uuid import UUID
import sqlalchemy as sa
from flask import request, send_file
from flask_restx import Resource, marshal
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import asc, desc, func, select
from werkzeug.exceptions import Forbidden, NotFound
@ -34,14 +34,16 @@ from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.document_fields import (
document_fields,
document_status_fields,
document_with_segments_fields,
DocumentMetadataResponse,
DocumentResponse,
DocumentStatusListResponse,
DocumentStatusResponse,
normalize_enum,
)
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
from libs.datetime_utils import naive_utc_now
from libs.helper import to_timestamp
from libs.helper import dump_response, to_timestamp
from libs.login import current_account_with_tenant, login_required
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
@ -74,12 +76,6 @@ from ..wraps import (
logger = logging.getLogger(__name__)
def _normalize_enum(value: Any) -> Any:
if isinstance(value, str) or value is None:
return value
return getattr(value, "value", value)
class DatasetResponse(ResponseModel):
id: str
name: str
@ -93,7 +89,7 @@ class DatasetResponse(ResponseModel):
@field_validator("data_source_type", "indexing_technique", mode="before")
@classmethod
def _normalize_enum_fields(cls, value: Any) -> Any:
return _normalize_enum(value)
return normalize_enum(value)
@field_validator("created_at", mode="before")
@classmethod
@ -101,61 +97,10 @@ class DatasetResponse(ResponseModel):
return to_timestamp(value)
class DocumentMetadataResponse(ResponseModel):
id: str
name: str
type: str
value: str | None = None
class DocumentResponse(ResponseModel):
id: str
position: int | None = None
data_source_type: str | None = None
data_source_info: Any = Field(default=None, validation_alias="data_source_info_dict")
data_source_detail_dict: Any = None
dataset_process_rule_id: str | None = None
name: str
created_from: str | None = None
created_by: str | None = None
created_at: int | None = None
tokens: int | None = None
indexing_status: str | None = None
error: str | None = None
enabled: bool | None = None
disabled_at: int | None = None
disabled_by: str | None = None
archived: bool | None = None
display_status: str | None = None
word_count: int | None = None
hit_count: int | None = None
doc_form: str | None = None
doc_metadata: list[DocumentMetadataResponse] = Field(default_factory=list, validation_alias="doc_metadata_details")
summary_index_status: str | None = None
need_summary: bool | None = None
@field_validator("data_source_type", "indexing_status", "display_status", "doc_form", mode="before")
@classmethod
def _normalize_enum_fields(cls, value: Any) -> Any:
return _normalize_enum(value)
@field_validator("doc_metadata", mode="before")
@classmethod
def _normalize_doc_metadata(cls, value: Any) -> list[Any]:
if value is None:
return []
return value
@field_validator("created_at", "disabled_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
class DocumentWithSegmentsResponse(DocumentResponse):
process_rule_dict: Any = None
completed_segments: int | None = None
total_segments: int | None = None
completed_segments: int | None = Field(default=None, exclude_if=lambda value: value is None)
total_segments: int | None = Field(default=None, exclude_if=lambda value: value is None)
class DatasetAndDocumentResponse(ResponseModel):
@ -190,6 +135,14 @@ class DocumentDatasetListParam(BaseModel):
fetch_val: str = Field("false", alias="fetch")
class DocumentWithSegmentsListResponse(ResponseModel):
data: list[DocumentWithSegmentsResponse]
has_more: bool
limit: int
total: int
page: int
register_schema_models(
console_ns,
KnowledgeConfig,
@ -200,13 +153,19 @@ register_schema_models(
GenerateSummaryPayload,
DocumentMetadataUpdatePayload,
DocumentBatchDownloadZipPayload,
)
register_response_schema_models(
console_ns,
SimpleResultMessageResponse,
SimpleResultResponse,
UrlResponse,
DatasetResponse,
DocumentMetadataResponse,
DocumentResponse,
DocumentWithSegmentsResponse,
DatasetAndDocumentResponse,
DocumentWithSegmentsListResponse,
)
register_response_schema_models(console_ns, SimpleResultMessageResponse, SimpleResultResponse, UrlResponse)
class DocumentResource(Resource):
@ -312,7 +271,11 @@ class DatasetDocumentListApi(Resource):
"status": "Filter documents by display status",
}
)
@console_ns.response(200, "Documents retrieved successfully")
@console_ns.response(
200,
"Documents retrieved successfully",
console_ns.models[DocumentWithSegmentsListResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@ -425,18 +388,15 @@ class DatasetDocumentListApi(Resource):
)
document.completed_segments = completed_segments
document.total_segments = total_segments
data = marshal(documents, document_with_segments_fields)
else:
data = marshal(documents, document_fields)
response = {
"data": data,
"data": documents,
"has_more": len(documents) == limit,
"limit": limit,
"total": paginated_documents.total,
"page": page,
}
return response
return dump_response(DocumentWithSegmentsListResponse, response)
@setup_required
@login_required
@ -482,9 +442,7 @@ class DatasetDocumentListApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
return DatasetAndDocumentResponse.model_validate(
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
).model_dump(mode="json")
return dump_response(DatasetAndDocumentResponse, {"dataset": dataset, "documents": documents, "batch": batch})
@setup_required
@login_required
@ -567,9 +525,7 @@ class DatasetInitApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
return DatasetAndDocumentResponse.model_validate(
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
).model_dump(mode="json")
return dump_response(DatasetAndDocumentResponse, {"dataset": dataset, "documents": documents, "batch": batch})
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate")
@ -742,6 +698,9 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
@console_ns.route("/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
class DocumentBatchIndexingStatusApi(DocumentResource):
@console_ns.response(
200, "Indexing status retrieved successfully", console_ns.models[DocumentStatusListResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@ -784,9 +743,8 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
"completed_segments": completed_segments,
"total_segments": total_segments,
}
documents_status.append(marshal(document_dict, document_status_fields))
data = {"data": documents_status}
return data
documents_status.append(document_dict)
return dump_response(DocumentStatusListResponse, {"data": documents_status})
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
@ -794,7 +752,9 @@ class DocumentIndexingStatusApi(DocumentResource):
@console_ns.doc("get_document_indexing_status")
@console_ns.doc(description="Get document indexing status")
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@console_ns.response(200, "Indexing status retrieved successfully")
@console_ns.response(
200, "Indexing status retrieved successfully", console_ns.models[DocumentStatusResponse.__name__]
)
@console_ns.response(404, "Document not found")
@setup_required
@login_required
@ -839,7 +799,7 @@ class DocumentIndexingStatusApi(DocumentResource):
"completed_segments": completed_segments,
"total_segments": total_segments,
}
return marshal(document_dict, document_status_fields)
return dump_response(DocumentStatusResponse, document_dict)
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
@ -1304,7 +1264,7 @@ class DocumentRenameApi(DocumentResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return DocumentResponse.model_validate(document, from_attributes=True).model_dump(mode="json")
return dump_response(DocumentResponse, document)
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")

View File

@ -1,11 +1,12 @@
import uuid
from typing import Literal
from typing import cast as type_cast
from uuid import UUID
from flask import request
from flask_restx import Resource, marshal
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy import String, cast, func, or_, select
from sqlalchemy import String, case, cast, func, literal, or_, select
from sqlalchemy.dialects.postgresql import JSONB
from werkzeug.exceptions import Forbidden, NotFound
@ -13,7 +14,12 @@ import services
from configs import dify_config
from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.common.schema import (
query_params_from_model,
query_params_from_request,
register_response_schema_models,
register_schema_models,
)
from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import (
@ -34,9 +40,17 @@ from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.base import ResponseModel
from fields.segment_fields import child_chunk_fields, segment_fields
from fields.segment_fields import (
ChildChunkDetailResponse,
ChildChunkListResponse,
ChildChunkResponse,
SegmentDetailResponse,
SegmentResponse,
segment_response_with_summary,
segment_responses_with_summaries,
)
from graphon.model_runtime.entities.model_entities import ModelType
from libs.helper import escape_like_pattern
from libs.helper import dump_response, escape_like_pattern
from libs.login import current_account_with_tenant, login_required
from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile
@ -44,20 +58,10 @@ from services.dataset_service import DatasetService, DocumentService, SegmentSer
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
from services.summary_index_service import SummaryIndexService
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
def _get_segment_with_summary(segment, dataset_id):
"""Helper function to marshal segment and add summary information."""
from services.summary_index_service import SummaryIndexService
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
# Query summary for this segment (only enabled summaries)
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict
class SegmentListQuery(BaseModel):
limit: int = Field(default=20, ge=1, le=100)
status: list[str] = Field(default_factory=list)
@ -67,6 +71,16 @@ class SegmentListQuery(BaseModel):
page: int = Field(default=1, ge=1)
class SegmentIdListQuery(BaseModel):
segment_id: list[str] = Field(default_factory=list, description="Segment IDs")
class ChildChunkListQuery(BaseModel):
limit: int = Field(default=20, ge=1, le=100)
keyword: str | None = None
page: int = Field(default=1, ge=1)
class SegmentCreatePayload(BaseModel):
content: str
answer: str | None = None
@ -92,13 +106,35 @@ class SegmentBatchImportStatusResponse(ResponseModel):
job_status: str
class ConsoleSegmentListResponse(ResponseModel):
data: list[SegmentResponse]
limit: int
total: int
total_pages: int
page: int
class ChildChunkBatchUpdateResponse(ResponseModel):
data: list[ChildChunkResponse]
class ChildChunkBatchUpdatePayload(BaseModel):
chunks: list[ChildChunkUpdateArgs]
class SegmentDocParams:
DATASET_DOCUMENT = {"dataset_id": "Dataset ID", "document_id": "Document ID"}
DATASET_DOCUMENT_ACTION = {**DATASET_DOCUMENT, "action": "Action"}
DATASET_DOCUMENT_SEGMENT = {**DATASET_DOCUMENT, "segment_id": "Segment ID"}
DATASET_DOCUMENT_PARENT_SEGMENT = {**DATASET_DOCUMENT, "segment_id": "Parent segment ID"}
DATASET_DOCUMENT_CHILD_CHUNK = {**DATASET_DOCUMENT_PARENT_SEGMENT, "child_chunk_id": "Child chunk ID"}
register_schema_models(
console_ns,
SegmentListQuery,
SegmentIdListQuery,
ChildChunkListQuery,
SegmentCreatePayload,
SegmentUpdatePayload,
BatchImportPayload,
@ -107,11 +143,24 @@ register_schema_models(
ChildChunkBatchUpdatePayload,
ChildChunkUpdateArgs,
)
register_response_schema_models(console_ns, SegmentBatchImportStatusResponse, SimpleResultResponse)
register_response_schema_models(
console_ns,
SegmentResponse,
ConsoleSegmentListResponse,
SegmentDetailResponse,
ChildChunkDetailResponse,
ChildChunkListResponse,
ChildChunkBatchUpdateResponse,
SegmentBatchImportStatusResponse,
SimpleResultResponse,
)
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
class DatasetDocumentSegmentListApi(Resource):
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
@console_ns.doc(params=query_params_from_model(SegmentListQuery))
@console_ns.response(200, "Segments retrieved successfully", console_ns.models[ConsoleSegmentListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@ -134,12 +183,7 @@ class DatasetDocumentSegmentListApi(Resource):
if not document:
raise NotFound("Document not found.")
args = SegmentListQuery.model_validate(
{
**request.args.to_dict(),
"status": request.args.getlist("status"),
}
)
args = query_params_from_request(SegmentListQuery, list_fields=("status",))
page = args.page
limit = min(args.limit, 100)
@ -169,12 +213,17 @@ class DatasetDocumentSegmentListApi(Resource):
# Use database-specific methods for JSON array search
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
# PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text
# Guard with jsonb_typeof to avoid "cannot extract elements from a scalar" error
# when keywords is null or a non-array JSON value.
# Feed the set-returning function a JSON array in every row. Filtering in
# the subquery is not enough because PostgreSQL can still evaluate the
# SRF on scalar JSON before applying the predicate.
keywords_jsonb = cast(DocumentSegment.keywords, JSONB)
keywords_array = case(
(func.jsonb_typeof(keywords_jsonb) == "array", keywords_jsonb),
else_=cast(literal("[]"), JSONB),
)
keywords_condition = func.array_to_string(
func.array(
select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB)))
.where(func.jsonb_typeof(cast(DocumentSegment.keywords, JSONB)) == "array")
select(func.jsonb_array_elements_text(keywords_array))
.correlate(DocumentSegment)
.scalar_subquery()
),
@ -200,38 +249,30 @@ class DatasetDocumentSegmentListApi(Resource):
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
# Query summaries for all segments in this page (batch query for efficiency)
segment_ids = [segment.id for segment in segments.items]
summaries = {}
segment_list = list(segments.items)
segment_ids = [segment.id for segment in segment_list]
summaries: dict[str, str | None] = {}
if segment_ids:
from services.summary_index_service import SummaryIndexService
summary_records = SummaryIndexService.get_segments_summaries(
segment_ids=segment_ids, dataset_id=dataset_id_str
)
# Only include enabled summaries (already filtered by service)
summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()}
# Add summary to each segment
segments_with_summary = []
for segment in segments.items:
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
segment_dict["summary"] = summaries.get(segment.id)
segments_with_summary.append(segment_dict)
response = {
"data": segments_with_summary,
"data": segment_responses_with_summaries(segment_list, summaries),
"limit": limit,
"total": segments.total,
"total_pages": segments.pages,
"page": page,
}
return response, 200
return dump_response(ConsoleSegmentListResponse, response), 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
@console_ns.doc(params=query_params_from_model(SegmentIdListQuery))
@console_ns.response(204, "Segments deleted successfully")
def delete(self, dataset_id: UUID, document_id: UUID):
current_user, _ = current_account_with_tenant()
@ -263,6 +304,8 @@ class DatasetDocumentSegmentListApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>")
class DatasetDocumentSegmentApi(Resource):
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_ACTION)
@console_ns.doc(params=query_params_from_model(SegmentIdListQuery))
@setup_required
@login_required
@account_initialization_required
@ -316,11 +359,12 @@ class DatasetDocumentSegmentApi(Resource):
SegmentService.update_segments_status(segment_ids, action, dataset, document)
except Exception as e:
raise InvalidActionError(str(e))
return {"result": "success"}, 200
return dump_response(SimpleResultResponse, {"result": "success"}), 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
class DatasetDocumentSegmentAddApi(Resource):
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
@setup_required
@login_required
@account_initialization_required
@ -328,6 +372,7 @@ class DatasetDocumentSegmentAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
@console_ns.response(200, "Segment created successfully", console_ns.models[SegmentDetailResponse.__name__])
def post(self, dataset_id: UUID, document_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@ -367,18 +412,25 @@ class DatasetDocumentSegmentAddApi(Resource):
payload = SegmentCreatePayload.model_validate(console_ns.payload or {})
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
segment = SegmentService.create_segment(payload_dict, document, dataset)
return {"data": _get_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
segment = type_cast(DocumentSegment, SegmentService.create_segment(payload_dict, document, dataset))
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id_str)
response = {
"data": segment_response_with_summary(segment, summary.summary_content if summary else None),
"doc_form": document.doc_form,
}
return dump_response(SegmentDetailResponse, response), 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
class DatasetDocumentSegmentUpdateApi(Resource):
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
@console_ns.response(200, "Segment updated successfully", console_ns.models[SegmentDetailResponse.__name__])
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@ -435,12 +487,18 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment = SegmentService.update_segment(
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
)
return {"data": _get_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id_str)
response = {
"data": segment_response_with_summary(segment, summary.summary_content if summary else None),
"doc_form": document.doc_form,
}
return dump_response(SegmentDetailResponse, response), 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
@console_ns.response(204, "Segment deleted successfully")
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@ -518,11 +576,11 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
try:
# async job
job_id = str(uuid.uuid4())
indexing_cache_key = f"segment_batch_import_{str(job_id)}"
indexing_cache_key = f"segment_batch_import_{job_id}"
# send batch add segments task
redis_client.setnx(indexing_cache_key, "waiting")
batch_create_segment_to_index_task.delay(
str(job_id),
job_id,
upload_file_id,
dataset_id_str,
document_id_str,
@ -531,7 +589,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
)
except Exception as e:
return {"error": str(e)}, 500
return {"job_id": job_id, "job_status": "waiting"}, 200
return dump_response(SegmentBatchImportStatusResponse, {"job_id": job_id, "job_status": "waiting"}), 200
@console_ns.response(200, "Batch import status", console_ns.models[SegmentBatchImportStatusResponse.__name__])
@setup_required
@ -546,11 +604,13 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
if cache_result is None:
raise ValueError("The job does not exist.")
return {"job_id": job_id, "job_status": cache_result.decode()}, 200
response = {"job_id": job_id, "job_status": cache_result.decode()}
return dump_response(SegmentBatchImportStatusResponse, response), 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks")
class ChildChunkAddApi(Resource):
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
@setup_required
@login_required
@account_initialization_required
@ -558,6 +618,7 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
@console_ns.response(200, "Child chunk created successfully", console_ns.models[ChildChunkDetailResponse.__name__])
def post(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@ -608,8 +669,11 @@ class ChildChunkAddApi(Resource):
child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
return dump_response(ChildChunkDetailResponse, {"data": child_chunk}), 200
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
@console_ns.doc(params=query_params_from_model(ChildChunkListQuery))
@console_ns.response(200, "Child chunks retrieved successfully", console_ns.models[ChildChunkListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@ -637,13 +701,7 @@ class ChildChunkAddApi(Resource):
)
if not segment:
raise NotFound("Segment not found.")
args = SegmentListQuery.model_validate(
{
"limit": request.args.get("limit", default=20, type=int),
"keyword": request.args.get("keyword"),
"page": request.args.get("page", default=1, type=int),
}
)
args = query_params_from_request(ChildChunkListQuery, use_defaults_for_malformed_ints=True)
page = args.page
limit = min(args.limit, 100)
@ -652,19 +710,27 @@ class ChildChunkAddApi(Resource):
child_chunks = SegmentService.get_child_chunks(
segment_id_str, document_id_str, dataset_id_str, page, limit, keyword
)
return {
"data": marshal(child_chunks.items, child_chunk_fields),
response = {
"data": child_chunks.items,
"total": child_chunks.total,
"total_pages": child_chunks.pages,
"page": page,
"limit": limit,
}, 200
}
return dump_response(ChildChunkListResponse, response), 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
@console_ns.response(
200,
"Child chunks updated successfully",
console_ns.models[ChildChunkBatchUpdateResponse.__name__],
)
@console_ns.expect(console_ns.models[ChildChunkBatchUpdatePayload.__name__])
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@ -702,7 +768,7 @@ class ChildChunkAddApi(Resource):
child_chunks = SegmentService.update_child_chunks(payload.chunks, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
return dump_response(ChildChunkBatchUpdateResponse, {"data": child_chunks}), 200
@console_ns.route(
@ -713,6 +779,7 @@ class ChildChunkUpdateApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
@console_ns.response(204, "Child chunk deleted successfully")
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@ -743,7 +810,7 @@ class ChildChunkUpdateApi(Resource):
child_chunk = db.session.scalar(
select(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id_str),
ChildChunk.id == child_chunk_id_str,
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id_str,
@ -770,7 +837,9 @@ class ChildChunkUpdateApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
@console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
@console_ns.response(200, "Child chunk updated successfully", console_ns.models[ChildChunkDetailResponse.__name__])
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@ -800,7 +869,7 @@ class ChildChunkUpdateApi(Resource):
child_chunk = db.session.scalar(
select(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id_str),
ChildChunk.id == child_chunk_id_str,
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id_str,
@ -822,4 +891,4 @@ class ChildChunkUpdateApi(Resource):
child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
return dump_response(ChildChunkDetailResponse, {"data": child_chunk}), 200

View File

@ -7,14 +7,20 @@ from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import MetadataUpdatePayload
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
setup_required,
with_current_user,
)
from fields.dataset_fields import (
DatasetMetadataBuiltInFieldsResponse,
DatasetMetadataListResponse,
DatasetMetadataResponse,
)
from libs.helper import dump_response
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.account import Account
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
DocumentMetadataOperation,
@ -43,8 +49,8 @@ class DatasetMetadataCreateApi(Resource):
@enterprise_license_required
@console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__])
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
def post(self, dataset_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, dataset_id: UUID):
metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
dataset_id_str = str(dataset_id)
@ -80,8 +86,8 @@ class DatasetMetadataApi(Resource):
@enterprise_license_required
@console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__])
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
def patch(self, dataset_id: UUID, metadata_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def patch(self, current_user: Account, dataset_id: UUID, metadata_id: UUID):
payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
name = payload.name
@ -100,8 +106,8 @@ class DatasetMetadataApi(Resource):
@account_initialization_required
@enterprise_license_required
@console_ns.response(204, "Metadata deleted successfully")
def delete(self, dataset_id: UUID, metadata_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def delete(self, current_user: Account, dataset_id: UUID, metadata_id: UUID):
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -137,8 +143,8 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
@account_initialization_required
@enterprise_license_required
@console_ns.response(204, "Action completed successfully")
def post(self, dataset_id: UUID, action: Literal["enable", "disable"]):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, dataset_id: UUID, action: Literal["enable", "disable"]):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -165,8 +171,8 @@ class DocumentMetadataEditApi(Resource):
204,
"Documents metadata updated successfully",
)
def post(self, dataset_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, dataset_id: UUID):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:

View File

@ -198,12 +198,13 @@ class DatasourceAuth(Resource):
def get(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
_, current_tenant_id = current_account_with_tenant()
user, current_tenant_id = current_account_with_tenant()
datasources = datasource_provider_service.list_datasource_credentials(
tenant_id=current_tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
user=user,
)
return {"result": datasources}, 200

View File

@ -11,10 +11,13 @@ from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
setup_required,
with_current_tenant_id,
with_current_user,
)
from extensions.ext_database import db
from fields.dataset_fields import dataset_detail_fields
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
@ -35,9 +38,10 @@ class CreateRagPipelineDatasetApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account):
payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {})
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
@ -79,10 +83,10 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account):
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_dataset_editor:
raise Forbidden()
dataset = DatasetService.create_empty_rag_pipeline_dataset(

View File

@ -10,6 +10,7 @@ from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user,
)
from extensions.ext_database import db
from fields.rag_pipeline_fields import (
@ -17,7 +18,8 @@ from fields.rag_pipeline_fields import (
pipeline_import_check_dependencies_fields,
pipeline_import_fields,
)
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.account import Account
from models.dataset import Pipeline
from services.entities.dsl_entities import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
@ -62,9 +64,9 @@ class RagPipelineImportApi(Resource):
@edit_permission_required
@marshal_with(pipeline_import_model)
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
def post(self):
@with_current_user
def post(self, current_user: Account):
# Check user role first
current_user, _ = current_account_with_tenant()
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
# Use a plain Session so that caught exceptions inside the service
@ -105,9 +107,8 @@ class RagPipelineImportConfirmApi(Resource):
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_model)
def post(self, import_id: str):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, import_id: str):
with Session(db.engine, expire_on_commit=False) as session:
import_service = RagPipelineDslService(session)
account = current_user

View File

@ -18,6 +18,7 @@ from controllers.console.app.error import (
)
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import with_current_user_id
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
@ -135,20 +136,18 @@ class CompletionApi(InstalledAppResource):
)
class CompletionStopApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def post(self, installed_app: InstalledApp, task_id: str):
@with_current_user_id
def post(self, current_user_id: str, installed_app: InstalledApp, task_id: str):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.EXPLORE,
user_id=current_user.id,
user_id=current_user_id,
app_mode=AppMode.value_of(app_model.mode),
)
@ -215,7 +214,8 @@ class ChatApi(InstalledAppResource):
)
class ChatStopApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def post(self, installed_app: InstalledApp, task_id: str):
@with_current_user_id
def post(self, current_user_id: str, installed_app: InstalledApp, task_id: str):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
@ -223,13 +223,10 @@ class ChatStopApi(InstalledAppResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.EXPLORE,
user_id=current_user.id,
user_id=current_user_id,
app_mode=app_mode,
)

View File

@ -12,14 +12,19 @@ from controllers.common.fields import SimpleMessageResponse, SimpleResultMessage
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
with_current_tenant_id,
with_current_user,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from graphon.file import helpers as file_helpers
from libs.datetime_utils import naive_utc_now
from libs.helper import to_timestamp
from libs.login import current_account_with_tenant, login_required
from models import App, InstalledApp, RecommendedApp
from libs.login import login_required
from models import Account, App, InstalledApp, RecommendedApp
from models.model import IconType
from services.account_service import TenantService
from services.enterprise.enterprise_service import EnterpriseService
@ -131,9 +136,10 @@ class InstalledAppsListApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__])
def get(self):
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
query = InstalledAppsListQuery.model_validate(request.args.to_dict())
current_user, current_tenant_id = current_account_with_tenant()
if query.app_id:
installed_apps = db.session.scalars(
@ -212,7 +218,8 @@ class InstalledAppsListApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("apps")
@console_ns.response(200, "Success", console_ns.models[SimpleMessageResponse.__name__])
def post(self):
@with_current_tenant_id
def post(self, current_tenant_id: str):
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
recommended_app = db.session.scalar(
@ -221,8 +228,6 @@ class InstalledAppsListApi(Resource):
if recommended_app is None:
raise NotFound("Recommended app not found")
_, current_tenant_id = current_account_with_tenant()
app = db.session.get(App, payload.app_id)
if app is None:
@ -262,8 +267,8 @@ class InstalledAppApi(InstalledAppResource):
"""
@console_ns.response(204, "App uninstalled successfully")
def delete(self, installed_app: InstalledApp):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def delete(self, current_tenant_id: str, installed_app: InstalledApp):
if installed_app.app_owner_tenant_id == current_tenant_id:
raise BadRequest("You can't uninstall an app owned by the current tenant")

View File

@ -2,13 +2,36 @@ from flask_restx import Resource
from werkzeug.exceptions import Unauthorized
from controllers.common.schema import register_response_schema_models
from fields.base import ResponseModel
from libs.helper import dump_response
from libs.login import current_user, login_required
from services.feature_service import FeatureModel, FeatureService, LimitationModel, SystemFeatureModel
from services.feature_service import (
FeatureModel,
FeatureService,
LimitationModel,
SystemFeatureModel,
)
from . import console_ns
from .wraps import account_initialization_required, cloud_utm_record, setup_required, with_current_tenant_id
register_response_schema_models(console_ns, FeatureModel, LimitationModel, SystemFeatureModel)
class TrialModelsResponse(ResponseModel):
trial_models: list[str]
class AppDslVersionResponse(ResponseModel):
app_dsl_version: str
register_response_schema_models(
console_ns,
AppDslVersionResponse,
FeatureModel,
LimitationModel,
SystemFeatureModel,
TrialModelsResponse,
)
@console_ns.route("/features")
@ -54,6 +77,43 @@ class FeatureVectorSpaceApi(Resource):
return FeatureService.get_vector_space(current_tenant_id).model_dump()
@console_ns.route("/trial-models")
class TrialModelsApi(Resource):
@console_ns.doc("get_trial_models")
@console_ns.doc(description="Get hosted trial model provider configuration")
@console_ns.response(
200,
"Success",
console_ns.models[TrialModelsResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
def get(self):
"""Get hosted trial model provider configuration for model-provider pages."""
return dump_response(
TrialModelsResponse,
{"trial_models": FeatureService.get_trial_models()},
)
@console_ns.route("/app-dsl-version")
class AppDslVersionApi(Resource):
@console_ns.doc("get_app_dsl_version")
@console_ns.doc(description="Get current app DSL version")
@console_ns.response(
200,
"Success",
console_ns.models[AppDslVersionResponse.__name__],
)
def get(self):
"""Get current app DSL version for workflow clipboard compatibility."""
return dump_response(
AppDslVersionResponse,
{"app_dsl_version": FeatureService.get_app_dsl_version()},
)
@console_ns.route("/system-features")
class SystemFeatureApi(Resource):
@console_ns.doc("get_system_features")

View File

@ -12,8 +12,15 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
model_validate,
setup_required,
with_current_tenant_id,
with_current_user,
)
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
@ -22,8 +29,8 @@ from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import App
from libs.login import login_required
from models import Account, App
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import WorkflowRun
@ -33,6 +40,8 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
logger = logging.getLogger(__name__)
register_schema_models(console_ns, HumanInputFormSubmitPayload)
def _jsonify_form_definition(form: Form) -> Response:
payload = form.get_definition().model_dump()
@ -45,9 +54,8 @@ class ConsoleHumanInputFormApi(Resource):
"""Console API for getting human input form definition."""
@staticmethod
def _ensure_console_access(form: Form):
_, current_tenant_id = current_account_with_tenant()
def _ensure_console_access(form: Form, current_tenant_id: str) -> None:
"""Ensure a console form token resolves only inside the current tenant."""
if form.tenant_id != current_tenant_id:
raise NotFoundError("App not found")
@ -59,7 +67,8 @@ class ConsoleHumanInputFormApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, form_token: str):
@with_current_tenant_id
def get(self, current_tenant_id: str, form_token: str):
"""
Get human input form definition by form token.
@ -70,13 +79,23 @@ class ConsoleHumanInputFormApi(Resource):
if form is None:
raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form)
self._ensure_console_access(form, current_tenant_id)
return _jsonify_form_definition(form)
@account_initialization_required
@login_required
def post(self, form_token: str):
@with_current_user
@with_current_tenant_id
@model_validate(HumanInputFormSubmitPayload)
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
def post(
self,
payload: HumanInputFormSubmitPayload,
current_tenant_id: str,
current_user: Account,
form_token: str,
):
"""
Submit human input form by form token.
@ -90,15 +109,12 @@ class ConsoleHumanInputFormApi(Resource):
"action": "Approve"
}
"""
payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
current_user, _ = current_account_with_tenant()
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form)
self._ensure_console_access(form, current_tenant_id)
self._ensure_console_recipient_type(form)
recipient_type = form.recipient_type
# The type checker is not smart enought to validate the following invariant.
@ -122,7 +138,9 @@ class ConsoleWorkflowEventsApi(Resource):
@account_initialization_required
@login_required
def get(self, workflow_run_id: str):
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, user: Account, workflow_run_id: str):
"""
Get workflow execution events stream after resume.
@ -130,8 +148,6 @@ class ConsoleWorkflowEventsApi(Resource):
Returns Server-Sent Events stream.
"""
user, tenant_id = current_account_with_tenant()
session_maker = sessionmaker(db.engine)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(

View File

@ -13,7 +13,7 @@ from controllers.common.errors import (
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import with_current_user
from core.helper import ssrf_proxy
from core.file import remote_fetcher
from extensions.ext_database import db
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from graphon.file import helpers as file_helpers
@ -36,9 +36,9 @@ class GetRemoteFileInfo(Resource):
@login_required
def get(self, url: str):
decoded_url = helpers.decode_remote_url(url, request.query_string)
resp = ssrf_proxy.head(decoded_url)
resp = remote_fetcher.make_request("HEAD", decoded_url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp = remote_fetcher.make_request("GET", decoded_url, timeout=3)
resp.raise_for_status()
return RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
@ -58,9 +58,9 @@ class RemoteFileUpload(Resource):
# Try to fetch remote file metadata/content first
try:
resp = ssrf_proxy.head(url=url)
resp = remote_fetcher.make_request("HEAD", url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
resp = remote_fetcher.make_request("GET", url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
# Normalize into a user-friendly error message expected by tests
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
@ -74,7 +74,7 @@ class RemoteFileUpload(Resource):
raise FileTooLargeError()
# Load content if needed
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
content = resp.content if resp.request.method == "GET" else remote_fetcher.make_request("GET", url).content
try:
upload_file = FileService(db.engine).upload_file(

View File

@ -18,7 +18,7 @@ from controllers.common.fields import (
SimpleResultResponse,
VerificationTokenResponse,
)
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.auth.error import (
EmailAlreadyInUseError,
@ -48,7 +48,7 @@ from fields.base import ResponseModel
from fields.member_fields import Account as AccountResponse
from graphon.file import helpers as file_helpers
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, extract_remote_ip, timezone, to_timestamp
from libs.helper import EmailStr, dump_response, extract_remote_ip, timezone, to_timestamp
from libs.login import current_account_with_tenant, login_required
from models import AccountIntegrate, InvitationCode
from models.account import AccountStatus, InvitationCodeStatus
@ -173,7 +173,6 @@ class CheckEmailUniquePayload(BaseModel):
register_schema_models(
console_ns,
AccountResponse,
AccountInitPayload,
AccountNamePayload,
AccountAvatarPayload,
@ -245,6 +244,7 @@ register_schema_models(
)
register_response_schema_models(
console_ns,
AccountResponse,
AvatarUrlResponse,
SimpleResultDataResponse,
SimpleResultResponse,
@ -329,9 +329,9 @@ class AccountNameApi(Resource):
@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource):
@console_ns.expect(console_ns.models[AccountAvatarQuery.__name__])
@console_ns.doc("get_account_avatar")
@console_ns.doc(description="Get account avatar url")
@console_ns.doc(params=query_params_from_model(AccountAvatarQuery))
@console_ns.response(200, "Success", console_ns.models[AvatarUrlResponse.__name__])
@setup_required
@login_required
@ -342,7 +342,7 @@ class AccountAvatarApi(Resource):
avatar = args.avatar
if avatar.startswith(("http://", "https://")):
return {"avatar_url": avatar}
return dump_response(AvatarUrlResponse, {"avatar_url": avatar})
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == avatar).limit(1))
if upload_file is None:
@ -355,7 +355,7 @@ class AccountAvatarApi(Resource):
raise NotFound("Avatar file not found")
avatar_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
return {"avatar_url": avatar_url}
return dump_response(AvatarUrlResponse, {"avatar_url": avatar_url})
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
@setup_required

View File

@ -1,9 +1,15 @@
from flask_restx import Resource, fields
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from services.agent_service import AgentService
@ -19,14 +25,10 @@ class AgentProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
user = current_user
user_id = user.id
tenant_id = current_tenant_id
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
return jsonable_encoder(AgentService.list_agent_providers(current_user.id, current_tenant_id))
@console_ns.route("/workspaces/current/agent-provider/<path:provider_name>")
@ -42,6 +44,7 @@ class AgentProviderApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_name: str):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, provider_name: str):
return jsonable_encoder(AgentService.get_agent_provider(current_user.id, current_tenant_id, provider_name))

View File

@ -4,11 +4,16 @@ from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from libs.login import current_account_with_tenant, login_required
from models import TenantAccountRole
from libs.login import login_required
from models import Account, TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
@ -29,8 +34,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account, provider: str):
if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden()
@ -72,8 +78,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str, config_id: str):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account, provider: str, config_id: str):
if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden()

View File

@ -25,12 +25,13 @@ from controllers.console.wraps import (
account_initialization_required,
is_allow_transfer_owner,
setup_required,
with_current_user,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.member_fields import AccountWithRole, AccountWithRoleList
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.account import Account, TenantAccountJoin, TenantAccountRole
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError
@ -136,8 +137,8 @@ class MemberListApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
def get(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account):
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_tenant_members(current_user.current_tenant)
@ -154,7 +155,8 @@ class MemberInviteEmailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = MemberInvitePayload.model_validate(payload)
@ -163,7 +165,6 @@ class MemberInviteEmailApi(Resource):
interface_language = args.language
if not TenantAccountRole.is_non_owner_role(invitee_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
current_user, _ = current_account_with_tenant()
inviter = current_user
if not inviter.current_tenant:
raise ValueError("No current tenant")
@ -223,8 +224,8 @@ class MemberCancelInviteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, member_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def delete(self, current_user: Account, member_id: UUID):
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.get(Account, str(member_id))
@ -256,14 +257,14 @@ class MemberUpdateRoleApi(Resource):
@setup_required
@login_required
@account_initialization_required
def put(self, member_id: UUID):
@with_current_user
def put(self, current_user: Account, member_id: UUID):
payload = console_ns.payload or {}
args = MemberRoleUpdatePayload.model_validate(payload)
new_role = args.role
if not TenantAccountRole.is_valid_role(new_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not _is_role_enabled(new_role, current_user.current_tenant.id):
@ -297,8 +298,8 @@ class DatasetOperatorMemberListApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
def get(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account):
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
@ -317,13 +318,13 @@ class SendOwnerTransferEmailApi(Resource):
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self):
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = OwnerTransferEmailPayload.model_validate(payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
current_user, _ = current_account_with_tenant()
# check if the current user is the owner of the workspace
if not current_user.current_tenant:
raise ValueError("No current tenant")
@ -355,11 +356,11 @@ class OwnerTransferCheckApi(Resource):
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self):
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = OwnerTransferCheckPayload.model_validate(payload)
# check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
@ -399,12 +400,12 @@ class OwnerTransfer(Resource):
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self, member_id: UUID):
@with_current_user
def post(self, current_user: Account, member_id: UUID):
payload = console_ns.payload or {}
args = OwnerTransferPayload.model_validate(payload)
# check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):

View File

@ -8,12 +8,19 @@ from pydantic import BaseModel, Field, field_validator
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import uuid_value
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
@ -95,10 +102,8 @@ class ModelProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
@with_current_tenant_id
def get(self, tenant_id: str):
payload = request.args.to_dict(flat=True)
args = ParserModelList.model_validate(payload)
@ -114,9 +119,8 @@ class ModelProviderCredentialApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
@with_current_tenant_id
def get(self, tenant_id: str, provider: str):
# if credential_id is not provided, return current used credential
payload = request.args.to_dict(flat=True)
args = ParserCredentialId.model_validate(payload)
@ -133,8 +137,8 @@ class ModelProviderCredentialApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider: str):
payload = console_ns.payload or {}
args = ParserCredentialCreate.model_validate(payload)
@ -157,9 +161,8 @@ class ModelProviderCredentialApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def put(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def put(self, current_tenant_id: str, provider: str):
payload = console_ns.payload or {}
args = ParserCredentialUpdate.model_validate(payload)
@ -184,8 +187,8 @@ class ModelProviderCredentialApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def delete(self, current_tenant_id: str, provider: str):
payload = console_ns.payload or {}
args = ParserCredentialDelete.model_validate(payload)
@ -205,8 +208,8 @@ class ModelProviderCredentialSwitchApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider: str):
payload = console_ns.payload or {}
args = ParserCredentialSwitch.model_validate(payload)
@ -225,8 +228,8 @@ class ModelProviderValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider: str):
payload = console_ns.payload or {}
args = ParserCredentialValidate.model_validate(payload)
@ -280,11 +283,8 @@ class PreferredProviderTypeUpdateApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
@with_current_tenant_id
def post(self, tenant_id: str, provider: str):
payload = console_ns.payload or {}
args = ParserPreferredProviderType.model_validate(payload)
@ -301,10 +301,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, provider: str):
if provider != "anthropic":
raise ValueError(f"provider name {provider} is invalid")
current_user, current_tenant_id = current_account_with_tenant()
BillingService.is_tenant_owner_or_admin(current_user)
data = BillingService.get_model_provider_payment_link(
provider_name=provider,

View File

@ -13,12 +13,14 @@ from controllers.console.wraps import (
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import uuid_value
from libs.login import login_required
from models import Account
from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService
@ -193,7 +195,7 @@ class ModelProviderModelApi(Resource):
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str, provider):
def get(self, tenant_id: str, provider: str):
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
@ -269,8 +271,9 @@ class ModelProviderModelCredentialApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, provider: str):
def get(self, tenant_id: str, user: Account, provider: str):
args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True))
model_provider_service = ModelProviderService()
@ -292,9 +295,13 @@ class ModelProviderModelCredentialApi(Resource):
)
if args.config_from == "predefined-model":
# Only the predefined-model branch needs visibility filtering by user.
# The account is injected once by the handler and only passed into the
# service branch that needs user-scoped credential visibility.
available_credentials = model_provider_service.get_provider_available_credentials(
tenant_id=tenant_id,
provider=provider,
user=user,
)
else:
available_credentials = model_provider_service.get_provider_model_available_credentials(

View File

@ -69,6 +69,7 @@ class BuiltinToolAddPayload(BaseModel):
credentials: dict[str, Any]
name: str | None = Field(default=None, max_length=30)
type: CredentialType
visibility: str | None = None
class BuiltinToolUpdatePayload(BaseModel):
@ -277,7 +278,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
@ -293,7 +294,7 @@ class ToolBuiltinProviderInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
@ -306,7 +307,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider):
def post(self, provider: str):
_, tenant_id = current_account_with_tenant()
payload = BuiltinToolCredentialDeletePayload.model_validate(console_ns.payload or {})
@ -324,7 +325,7 @@ class ToolBuiltinProviderAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
def post(self, provider: str):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@ -338,6 +339,7 @@ class ToolBuiltinProviderAddApi(Resource):
credentials=payload.credentials,
name=payload.name,
api_type=CredentialType.of(payload.type),
visibility=payload.visibility,
)
@ -348,7 +350,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider):
def post(self, provider: str):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@ -370,13 +372,20 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
_, tenant_id = current_account_with_tenant()
def get(self, provider: str):
user, tenant_id = current_account_with_tenant()
# Optional list of credential IDs to include even if visibility would hide them
# (used when a workflow/agent node still references another member's only_me credential).
include_credential_ids = request.args.getlist("include_credential_ids") or [
s for s in (request.args.get("include_credential_ids") or "").split(",") if s
]
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credentials(
tenant_id=tenant_id,
provider_name=provider,
user=user,
include_credential_ids=include_credential_ids or None,
)
)
@ -384,7 +393,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/icon")
class ToolBuiltinProviderIconApi(Resource):
@setup_required
def get(self, provider):
def get(self, provider: str):
icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider)
icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
@ -784,7 +793,7 @@ class ToolPluginOAuthApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name
@ -822,7 +831,7 @@ class ToolPluginOAuthApi(Resource):
@console_ns.route("/oauth/plugin/<path:provider>/tool/callback")
class ToolOAuthCallback(Resource):
@setup_required
def get(self, provider):
def get(self, provider: str):
context_id = request.cookies.get("context_id")
if not context_id:
raise Forbidden("context_id not found")
@ -859,7 +868,7 @@ class ToolOAuthCallback(Resource):
if not credentials:
raise Exception("the plugin credentials failed")
# add credentials to database
# add credentials to database — OAuth tokens default to only_me since they're personal
BuiltinToolManageService.add_builtin_tool_provider(
user_id=user_id,
tenant_id=tenant_id,
@ -867,6 +876,7 @@ class ToolOAuthCallback(Resource):
credentials=dict(credentials),
expires_at=expires_at,
api_type=CredentialType.OAUTH2,
visibility="only_me",
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
@ -878,7 +888,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider):
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {})
return BuiltinToolManageService.set_default_provider(
@ -910,7 +920,7 @@ class ToolOAuthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
@ -919,7 +929,7 @@ class ToolOAuthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, provider):
def delete(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
@ -931,7 +941,7 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
@ -945,13 +955,18 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
_, tenant_id = current_account_with_tenant()
def get(self, provider: str):
user, tenant_id = current_account_with_tenant()
include_credential_ids = request.args.getlist("include_credential_ids") or [
s for s in (request.args.get("include_credential_ids") or "").split(",") if s
]
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
tenant_id=tenant_id,
provider=provider,
user=user,
include_credential_ids=include_credential_ids or None,
)
)
@ -1151,7 +1166,7 @@ class ToolMCPDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_id):
def get(self, provider_id: str):
_, tenant_id = current_account_with_tenant()
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
@ -1180,7 +1195,7 @@ class ToolMCPUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_id):
def get(self, provider_id: str):
_, tenant_id = current_account_with_tenant()
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)

View File

@ -77,7 +77,7 @@ class TriggerProviderIconApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
@ -103,7 +103,7 @@ class TriggerProviderInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
"""Get info for a trigger provider"""
user = current_user
assert isinstance(user, Account)
@ -119,15 +119,18 @@ class TriggerSubscriptionListApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
"""List all trigger subscriptions for the current tenant's provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
try:
return jsonable_encoder(
TriggerProviderService.list_trigger_provider_subscriptions(
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
tenant_id=user.current_tenant_id,
provider_id=TriggerProviderID(provider),
user=user,
)
)
except ValueError as e:
@ -146,7 +149,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider):
def post(self, provider: str):
"""Add a new subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
@ -175,7 +178,7 @@ class TriggerSubscriptionBuilderGetApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def get(self, provider, subscription_builder_id):
def get(self, provider: str, subscription_builder_id: str):
"""Get a subscription instance for a trigger provider"""
return jsonable_encoder(
TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
@ -191,7 +194,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
def post(self, provider: str, subscription_builder_id: str):
"""Verify and update a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
@ -223,7 +226,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
def post(self, provider: str, subscription_builder_id: str):
"""Update a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
@ -257,7 +260,7 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def get(self, provider, subscription_builder_id):
def get(self, provider: str, subscription_builder_id: str):
"""Get the request logs for a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
@ -280,7 +283,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
def post(self, provider: str, subscription_builder_id: str):
"""Build a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
@ -404,7 +407,7 @@ class TriggerOAuthAuthorizeApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
"""Initiate OAuth authorization flow for a trigger provider"""
user = current_user
assert isinstance(user, Account)
@ -486,7 +489,7 @@ class TriggerOAuthAuthorizeApi(Resource):
@console_ns.route("/oauth/plugin/<path:provider>/trigger/callback")
class TriggerOAuthCallbackApi(Resource):
@setup_required
def get(self, provider):
def get(self, provider: str):
"""Handle OAuth callback for trigger provider"""
context_id = request.cookies.get("context_id")
if not context_id:
@ -554,7 +557,7 @@ class TriggerOAuthClientManageApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
"""Get OAuth client configuration for a provider"""
user = current_user
assert user.current_tenant_id is not None
@ -600,7 +603,7 @@ class TriggerOAuthClientManageApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider):
def post(self, provider: str):
"""Configure custom OAuth client for a provider"""
user = current_user
assert user.current_tenant_id is not None
@ -626,7 +629,7 @@ class TriggerOAuthClientManageApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider):
def delete(self, provider: str):
"""Remove custom OAuth client configuration"""
user = current_user
assert user.current_tenant_id is not None
@ -654,7 +657,7 @@ class TriggerSubscriptionVerifyApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_id):
def post(self, provider: str, subscription_id: str):
"""Verify credentials for an existing subscription (edit mode only)"""
user = current_user
assert user.current_tenant_id is not None

View File

@ -29,7 +29,7 @@ from controllers.console.wraps import (
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import TimestampField, to_timestamp
from libs.helper import TimestampField, dump_response, to_timestamp
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
from services.account_service import TenantService
@ -56,6 +56,11 @@ class WorkspaceCustomConfigPayload(BaseModel):
replace_webapp_logo: str | None = None
class WorkspaceCustomConfigResponse(ResponseModel):
remove_webapp_brand: bool | None = None
replace_webapp_logo: str | None = None
class WorkspaceInfoPayload(BaseModel):
name: str
@ -69,7 +74,7 @@ class TenantInfoResponse(ResponseModel):
role: str | None = None
in_trial: bool | None = None
trial_end_reason: str | None = None
custom_config: dict | None = None
custom_config: WorkspaceCustomConfigResponse | None = None
trial_credits: int | None = None
trial_credits_used: int | None = None
next_credit_reset_date: int | None = None
@ -101,9 +106,13 @@ register_schema_models(
SwitchWorkspacePayload,
WorkspaceCustomConfigPayload,
WorkspaceInfoPayload,
TenantInfoResponse,
)
register_response_schema_models(console_ns, WorkspacePermissionResponse)
register_response_schema_models(
console_ns,
TenantInfoResponse,
WorkspaceCustomConfigResponse,
WorkspacePermissionResponse,
)
provider_fields = {
"provider_name": fields.String,
@ -238,13 +247,7 @@ class TenantApi(Resource):
else:
raise Unauthorized("workspace is archived")
return (
TenantInfoResponse.model_validate(
WorkspaceService.get_tenant_info(tenant),
from_attributes=True,
).model_dump(mode="json"),
200,
)
return dump_response(TenantInfoResponse, WorkspaceService.get_tenant_info(tenant)), 200
@console_ns.route("/workspaces/switch")

View File

@ -7,7 +7,9 @@ from functools import wraps
from typing import Concatenate
from flask import abort, request
from pydantic import BaseModel, ValidationError
from sqlalchemy import select
from werkzeug.exceptions import UnprocessableEntity
from configs import dify_config
from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError
@ -512,9 +514,79 @@ def with_current_tenant_id[T, **P, R](
def with_current_user[T, **P, R](
view: Callable[Concatenate[T, Account, P], R],
) -> Callable[Concatenate[T, P], R]:
"""Inject the current authenticated Account into the handler as the first argument after self.
Usage::
class MyResource(Resource):
@login_required
@with_current_user
def get(self, current_user: Account):
...
"""
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
current_user, _ = current_account_with_tenant()
return view(self, current_user, *args, **kwargs)
return decorated
def with_current_user_id[T, **P, R](
view: Callable[Concatenate[T, str, P], R],
) -> Callable[Concatenate[T, P], R]:
"""Inject the current authenticated user's ID (as a string) into the handler.
Use this when the handler only needs the user ID and not the full Account object.
Usage::
class MyResource(Resource):
@login_required
@with_current_user_id
def get(self, current_user_id: str):
...
"""
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
current_user, _ = current_account_with_tenant()
return view(self, str(current_user.id), *args, **kwargs)
return decorated
def model_validate[T, M: BaseModel, **P, R](
model: type[M],
) -> Callable[
[Callable[Concatenate[T, M, P], R]],
Callable[Concatenate[T, P], R],
]:
"""Validate request data and inject the model instance as the first arg after self.
Source is determined by HTTP method:
GET/DELETE -> request.args
POST/PUT/PATCH -> JSON body
"""
def decorator(
view: Callable[Concatenate[T, M, P], R],
) -> Callable[Concatenate[T, P], R]:
@wraps(view)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
if request.method in ("GET", "DELETE"):
raw = request.args.to_dict(flat=True)
else:
raw = request.get_json(silent=True) or {}
try:
validated = model.model_validate(raw)
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
return view(self, validated, *args, **kwargs)
return wrapper
return decorator

View File

@ -45,6 +45,15 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
.limit(1)
)
else:
# Try id first (preserves the original "explicit end-user
# id → that specific user" semantics for callers that pass
# a known EndUser.id). Fall back to session_id so daemon-
# supplied session UUIDs dedup against the row created on
# the first Reverse Invocation call — without this, an
# id-only lookup never matched (create writes user_id to
# session_id, id is auto-generated) and a fresh EndUser
# was created per call, breaking multi-turn chat
# continuation (see #36736).
user_model = session.scalar(
select(EndUser)
.where(
@ -53,6 +62,15 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
)
.limit(1)
)
if user_model is None:
user_model = session.scalar(
select(EndUser)
.where(
EndUser.session_id == user_id,
EndUser.tenant_id == tenant_id,
)
.limit(1)
)
if not user_model:
user_model = EndUser(

View File

@ -7,7 +7,7 @@ from hmac import new as hmac_new
from flask import abort, request
from configs import dify_config
from extensions.ext_database import db
from core.db.session_factory import session_factory
from models.model import EndUser
@ -44,6 +44,8 @@ def enterprise_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P, R]:
"""Inject an EndUser for valid inner API HMAC auth, otherwise pass the request through unchanged."""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
if not dify_config.INNER_API:
@ -72,9 +74,9 @@ def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P,
if signature_base64 != token:
return view(*args, **kwargs)
kwargs["user"] = db.session.get(EndUser, user_id)
return view(*args, **kwargs)
with session_factory.create_session() as session:
kwargs["user"] = session.get(EndUser, user_id)
return view(*args, **kwargs)
return decorated

View File

@ -147,7 +147,7 @@ class AppDescribeApi(AppReadResource):
class AppListApi(Resource):
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@auth_router.guard_workspace(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def get(self, *, auth_data: AuthData):
try:
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))

View File

@ -1,11 +1,13 @@
from __future__ import annotations
from controllers.openapi.auth.conditions import (
EDITION_CE,
EDITION_EE,
HAS_ALLOWED_ROLES,
LOADED_APP_IS_PRIVATE,
PATH_HAS_APP_ID,
WEBAPP_AUTH_ENABLED,
WORKSPACE_MEMBERSHIP_REQUIRED,
WORKSPACE_SCOPED,
)
from controllers.openapi.auth.data import Edition
from controllers.openapi.auth.flow import When
@ -15,14 +17,18 @@ from controllers.openapi.auth.prepare import (
load_app,
load_app_access_mode,
load_tenant,
load_tenant_from_request,
load_workspace_role,
resolve_external_user,
)
from controllers.openapi.auth.verify import (
check_acl,
check_app_access,
check_membership,
check_app_api_enabled,
check_private_app_permission,
check_scope,
check_workspace_member,
check_workspace_mismatch,
check_workspace_role,
)
from libs.oauth_bearer import TokenType
@ -30,13 +36,17 @@ account_pipeline = AuthPipeline(
prepare=[
When(PATH_HAS_APP_ID, then=load_app),
When(PATH_HAS_APP_ID, then=load_tenant),
load_account, # all tokens here are account tokens
When(WORKSPACE_MEMBERSHIP_REQUIRED, then=load_tenant_from_request),
load_account,
When(WORKSPACE_SCOPED, then=load_workspace_role),
When(PATH_HAS_APP_ID & EDITION_EE, then=load_app_access_mode),
],
auth=[
When(PATH_HAS_APP_ID, then=check_app_api_enabled),
check_scope,
When(EDITION_CE & PATH_HAS_APP_ID, then=check_membership),
When(EDITION_EE & PATH_HAS_APP_ID & ~WEBAPP_AUTH_ENABLED, then=check_app_access),
When(WORKSPACE_SCOPED, then=check_workspace_member),
When(PATH_HAS_APP_ID, then=check_workspace_mismatch),
When(HAS_ALLOWED_ROLES, then=check_workspace_role),
When(PATH_HAS_APP_ID & EDITION_EE & WEBAPP_AUTH_ENABLED, then=check_acl),
When(EDITION_EE & LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
],
@ -50,6 +60,7 @@ external_sso_pipeline = AuthPipeline(
When(PATH_HAS_APP_ID, then=load_app_access_mode),
],
auth=[
When(PATH_HAS_APP_ID, then=check_app_api_enabled),
check_scope,
When(PATH_HAS_APP_ID & WEBAPP_AUTH_ENABLED, then=check_acl),
When(LOADED_APP_IS_PRIVATE, then=check_private_app_permission),

View File

@ -50,4 +50,11 @@ EDITION_SAAS = config_cond(lambda: current_edition() == Edition.SAAS)
WEBAPP_AUTH_ENABLED = config_cond(lambda: FeatureService.get_system_features().webapp_auth.enabled)
WORKSPACE_MEMBERSHIP_REQUIRED = request_cond(lambda ctx: ctx.workspace_membership)
HAS_ALLOWED_ROLES = request_cond(lambda ctx: ctx.allowed_roles is not None)
# Caller must belong to the resolved tenant: either an app-scoped path (tenant
# from the app) or an explicit workspace-membership path (tenant from request).
WORKSPACE_SCOPED = PATH_HAS_APP_ID | WORKSPACE_MEMBERSHIP_REQUIRED
LOADED_APP_IS_PRIVATE = data_cond(lambda data: data.app_access_mode == WebAppAccessMode.PRIVATE)

View File

@ -9,7 +9,7 @@ from werkzeug.exceptions import InternalServerError
from configs import dify_config
from libs.oauth_bearer import Scope, TokenType
from models.account import Account, Tenant
from models.account import Account, Tenant, TenantAccountRole
from models.model import App, EndUser
from services.enterprise.enterprise_service import WebAppAccessMode
@ -41,6 +41,8 @@ class RequestContext(BaseModel):
token_type: TokenType
scope: Scope | None = None
path_params: dict[str, str]
workspace_membership: bool = False
allowed_roles: frozenset[TenantAccountRole] | None = None
class AuthData(BaseModel):
@ -56,10 +58,14 @@ class AuthData(BaseModel):
external_identity: ExternalIdentity | None = None
path_params: dict[str, str] = Field(default_factory=dict)
allowed_roles: frozenset[TenantAccountRole] | None = None
app: App | None = None
tenant: Tenant | None = None
app_access_mode: WebAppAccessMode | None = None
tenant_role: TenantAccountRole | None = None
caller: Account | EndUser | None = None
caller_kind: Literal["account", "end_user"] | None = None

View File

@ -34,6 +34,7 @@ from libs.oauth_bearer import (
reset_auth_ctx,
set_auth_ctx,
)
from models.account import TenantAccountRole
from services.feature_service import FeatureService, LicenseStatus
@ -56,11 +57,15 @@ class AuthPipeline:
view: Callable,
*,
scope: Scope | None,
workspace_membership: bool = False,
allowed_roles: frozenset[TenantAccountRole] | None = None,
) -> Any:
req_ctx = RequestContext(
token_type=identity.token_type,
scope=scope,
path_params=dict(request.view_args or {}),
workspace_membership=workspace_membership,
allowed_roles=allowed_roles,
)
data = AuthData(
@ -71,6 +76,7 @@ class AuthPipeline:
scopes=frozenset(identity.scopes),
tenants=dict(identity.verified_tenants),
required_scope=scope,
allowed_roles=allowed_roles,
path_params=dict(req_ctx.path_params),
external_identity=(
ExternalIdentity(email=identity.subject_email, issuer=identity.subject_issuer)
@ -121,6 +127,41 @@ class PipelineRouter:
scope: Scope | None = None,
allowed_token_types: frozenset[TokenType] | None = None,
edition: frozenset[Edition] | None = None,
workspace_membership: bool = False,
allowed_roles: frozenset[TenantAccountRole] | None = None,
) -> Callable:
return self._make_decorator(
scope=scope,
allowed_token_types=allowed_token_types,
edition=edition,
workspace_membership=workspace_membership,
allowed_roles=allowed_roles,
)
def guard_workspace(
self,
*,
scope: Scope | None = None,
allowed_token_types: frozenset[TokenType] | None = None,
edition: frozenset[Edition] | None = None,
allowed_roles: frozenset[TenantAccountRole] | None = None,
) -> Callable:
return self._make_decorator(
scope=scope,
allowed_token_types=allowed_token_types,
edition=edition,
workspace_membership=True,
allowed_roles=allowed_roles,
)
def _make_decorator(
self,
*,
scope: Scope | None,
allowed_token_types: frozenset[TokenType] | None,
edition: frozenset[Edition] | None,
workspace_membership: bool,
allowed_roles: frozenset[TenantAccountRole] | None,
) -> Callable:
def decorator(view: Callable) -> Callable:
@wraps(view)
@ -132,6 +173,8 @@ class PipelineRouter:
scope=scope,
allowed_token_types=allowed_token_types,
edition=edition,
workspace_membership=workspace_membership,
allowed_roles=allowed_roles,
)
return decorated
@ -147,6 +190,8 @@ class PipelineRouter:
scope: Scope | None,
allowed_token_types: frozenset[TokenType] | None,
edition: frozenset[Edition] | None,
workspace_membership: bool = False,
allowed_roles: frozenset[TenantAccountRole] | None = None,
) -> Any:
# 404 not 403 — this edition doesn't expose the feature at all
if edition is not None and current_edition() not in edition:
@ -182,7 +227,15 @@ class PipelineRouter:
if not license_checked and Edition.EE in route.required_edition:
_check_license()
return route.pipeline._run(identity, args, kwargs, view, scope=scope)
return route.pipeline._run(
identity,
args,
kwargs,
view,
scope=scope,
workspace_membership=workspace_membership,
allowed_roles=allowed_roles,
)
def _should_run(step: Any, req_ctx: RequestContext, data: AuthData | None) -> bool:

View File

@ -1,5 +1,8 @@
from __future__ import annotations
import uuid
from flask import request
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound, Unauthorized
from controllers.openapi.auth.data import AuthData
@ -13,16 +16,18 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppAcce
def load_app(data: AuthData) -> None:
if data.app is not None:
return
app_id = data.path_params["app_id"]
app = AppService.get_app_by_id(db.session, app_id)
if not app or app.status != "normal":
raise NotFound("app not found")
if not app.enable_api:
raise Forbidden("service_api_disabled")
data.app = app
def load_tenant(data: AuthData) -> None:
if data.tenant is not None:
return
if data.app is None:
raise InternalServerError("pipeline_invariant_violated: app not loaded before load_tenant")
tenant = TenantService.get_tenant_by_id(db.session, str(data.app.tenant_id))
@ -31,7 +36,25 @@ def load_tenant(data: AuthData) -> None:
data.tenant = tenant
def load_tenant_from_request(data: AuthData) -> None:
if data.tenant is not None:
return
workspace_id = data.path_params.get("workspace_id") or request.args.get("workspace_id")
if not workspace_id:
raise NotFound("workspace not found")
try:
uuid.UUID(workspace_id)
except ValueError:
raise NotFound("workspace not found")
tenant = TenantService.get_tenant_by_id(db.session, workspace_id)
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
raise NotFound("workspace not found")
data.tenant = tenant
def load_account(data: AuthData) -> None:
if data.caller is not None:
return
account = AccountService.get_account_by_id(db.session, str(data.account_id))
if account is None:
raise Unauthorized("account not found")
@ -41,6 +64,19 @@ def load_account(data: AuthData) -> None:
data.caller_kind = "account"
def load_workspace_role(data: AuthData) -> None:
if data.tenant_role is not None:
return
if data.tenant is None or data.account_id is None:
return
if data.caller is not None and getattr(data.caller, "status", None) != "active":
return
role = TenantService.get_account_role_in_tenant(db.session, str(data.account_id), str(data.tenant.id))
if role is None:
return
data.tenant_role = role
def resolve_external_user(data: AuthData) -> None:
if data.tenant is None or data.app is None or data.external_identity is None:
raise Unauthorized("missing context for external user resolution")

View File

@ -1,77 +0,0 @@
"""Workspace role gate.
Layered on top of `validate_bearer` + `accept_subjects(SubjectType.ACCOUNT)`
for routes whose access depends on the caller's `TenantAccountJoin.role`
in the workspace named by the `workspace_id` path parameter.
Usage::
@openapi_ns.route("/workspaces/<string:workspace_id>/members")
class Members(Resource):
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
@require_workspace_role() # any member
def get(self, workspace_id: str): ...
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
def post(self, workspace_id: str): ...
Non-member callers get 404 (matching `GET /openapi/v1/workspaces/<id>`)
so workspace IDs do not leak across tenants. A member without one of the
allowed roles gets 403.
"""
from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import TypeVar
from werkzeug.exceptions import Forbidden, NotFound
from extensions.ext_database import db
from libs.oauth_bearer import try_get_auth_ctx
from models.account import TenantAccountRole
from services.account_service import TenantService
F = TypeVar("F", bound=Callable[..., object])
def require_workspace_role(*allowed_roles: TenantAccountRole) -> Callable[[F], F]:
"""Gate a route on the caller's role in ``workspace_id``.
Pass no roles to require only membership. Pass one or more roles to
require the caller's role be in that set.
"""
allowed = frozenset(allowed_roles)
def deco(fn: F) -> F:
@wraps(fn)
def wrapper(*args: object, **kwargs: object) -> object:
ctx = try_get_auth_ctx()
if ctx is None or ctx.account_id is None:
raise RuntimeError(
"require_workspace_role called without account-bearer context; "
"stack validate_bearer + accept_subjects(SubjectType.ACCOUNT) above it"
)
workspace_id = kwargs.get("workspace_id")
if not workspace_id:
raise RuntimeError("require_workspace_role expects a 'workspace_id' route parameter")
role = TenantService.get_account_role_in_tenant(db.session, str(ctx.account_id), str(workspace_id))
if role is None:
raise NotFound("workspace not found")
if allowed and role not in allowed:
raise Forbidden("insufficient workspace role")
return fn(*args, **kwargs)
return wrapper # type: ignore[return-value]
return deco

View File

@ -1,10 +1,11 @@
from __future__ import annotations
from werkzeug.exceptions import Forbidden, Unauthorized
from flask import request
from werkzeug.exceptions import Forbidden, NotFound, UnprocessableEntity
from controllers.openapi.auth.data import AuthData
from extensions.ext_database import db
from libs.oauth_bearer import Scope, TokenType, check_workspace_membership
from libs.oauth_bearer import Scope, TokenType
from services.account_service import AccountService, TenantService
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode
@ -17,17 +18,39 @@ def check_scope(data: AuthData) -> None:
raise Forbidden("insufficient_scope")
def check_membership(data: AuthData) -> None:
def check_workspace_member(data: AuthData) -> None:
"""Assert the caller belongs to the resolved tenant.
`load_workspace_role` stashes the membership role (None when the caller is
not a member or is inactive). A missing membership surfaces as 404, not
403, so workspace IDs don't leak across tenants.
"""
if data.tenant_role is None:
raise NotFound("workspace not found")
def check_workspace_mismatch(data: AuthData) -> None:
if data.tenant is None:
raise Unauthorized("tenant unset")
if data.account_id is None:
raise Unauthorized("account_id unset")
check_workspace_membership(
account_id=data.account_id,
tenant_id=data.tenant.id,
token_hash=data.token_hash,
membership_cache=data.tenants,
)
return
request_workspace_id = data.path_params.get("workspace_id") or request.args.get("workspace_id")
if request_workspace_id and request_workspace_id != str(data.tenant.id):
raise UnprocessableEntity("workspace_id does not match app's workspace")
def check_workspace_role(data: AuthData) -> None:
if data.allowed_roles is None:
return
if data.tenant_role is None:
raise NotFound("workspace not found")
if data.tenant_role not in data.allowed_roles:
raise Forbidden("insufficient workspace role")
def check_app_api_enabled(data: AuthData) -> None:
if data.app is None:
return
if not data.app.enable_api:
raise Forbidden("service_api_disabled")
def check_app_access(data: AuthData) -> None:

View File

@ -26,7 +26,12 @@ from werkzeug.exceptions import BadRequest
from configs import dify_config
from controllers.common.schema import query_params_from_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from controllers.openapi import openapi_ns
from controllers.openapi._models import (
AccountPayload,
@ -42,7 +47,6 @@ from controllers.openapi._models import (
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType, bearer_feature_required
from libs.rate_limit import (
LIMIT_APPROVE_CONSOLE,
@ -50,6 +54,7 @@ from libs.rate_limit import (
LIMIT_LOOKUP_PUBLIC,
rate_limit,
)
from models import Account
from services.account_service import TenantService
from services.oauth_device_flow import (
ACCOUNT_ISSUER_SENTINEL,
@ -206,11 +211,12 @@ class DeviceApproveApi(Resource):
@account_initialization_required
@bearer_feature_required
@rate_limit(LIMIT_APPROVE_CONSOLE)
def post(self):
@with_current_user
@with_current_tenant_id
def post(self, tenant: str, account: Account):
payload = _validate_json(DeviceMutateRequest)
user_code = payload.user_code.strip().upper()
account, tenant = current_account_with_tenant()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)

View File

@ -5,9 +5,8 @@ endpoints. Account bearers (dfoa_) see every tenant they're a member of.
External SSO bearers (dfoe_) have no account_id and so see an empty list —
that matches /openapi/v1/account.
Member-management endpoints are gated by both `accept_subjects` (SSO out)
and `require_workspace_role` (membership / role lookup against the path's
``workspace_id``).
Member-management endpoints use ``guard_workspace`` which enforces
workspace membership and optional role requirements via the auth pipeline.
"""
from __future__ import annotations
@ -37,7 +36,6 @@ from controllers.openapi._models import (
)
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from controllers.openapi.auth.role_gate import require_workspace_role
from extensions.ext_database import db
from libs.oauth_bearer import Scope, TokenType
from models import Account, Tenant, TenantAccountJoin
@ -152,8 +150,7 @@ class WorkspaceSwitchApi(Resource):
"""
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@require_workspace_role()
@auth_router.guard_workspace(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def post(self, workspace_id: str, *, auth_data: AuthData):
account = _load_account(auth_data.account_id)
@ -179,8 +176,7 @@ class WorkspaceMembersApi(Resource):
@openapi_ns.doc(params=query_params_from_model(MemberListQuery))
@openapi_ns.response(200, "Member list", openapi_ns.models[MemberListResponse.__name__])
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@require_workspace_role()
@auth_router.guard_workspace(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def get(self, workspace_id: str, *, auth_data: AuthData):
try:
query = MemberListQuery.model_validate(request.args.to_dict(flat=True))
@ -202,8 +198,11 @@ class WorkspaceMembersApi(Resource):
@openapi_ns.expect(openapi_ns.models[MemberInvitePayload.__name__])
@openapi_ns.response(201, "Member invited", openapi_ns.models[MemberInviteResponse.__name__])
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
@auth_router.guard_workspace(
scope=Scope.WORKSPACE_WRITE,
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
)
def post(self, workspace_id: str, *, auth_data: AuthData):
payload = _validate_body(MemberInvitePayload)
inviter = _load_account(auth_data.account_id)
@ -253,8 +252,11 @@ class WorkspaceMemberApi(Resource):
"""
@openapi_ns.response(200, "Member removed", openapi_ns.models[MemberActionResponse.__name__])
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
@auth_router.guard_workspace(
scope=Scope.WORKSPACE_WRITE,
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
)
def delete(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
operator = _load_account(auth_data.account_id)
tenant = _load_tenant(workspace_id)
@ -284,8 +286,11 @@ class WorkspaceMemberRoleApi(Resource):
@openapi_ns.expect(openapi_ns.models[MemberRoleUpdatePayload.__name__])
@openapi_ns.response(200, "Role updated", openapi_ns.models[MemberActionResponse.__name__])
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
@auth_router.guard_workspace(
scope=Scope.WORKSPACE_WRITE,
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
)
def put(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
payload = _validate_body(MemberRoleUpdatePayload)
operator = _load_account(auth_data.account_id)

View File

@ -6,7 +6,7 @@ from flask_restx import Resource
from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field, TypeAdapter
from controllers.common.schema import register_schema_models
from controllers.common.schema import query_params_from_model, register_schema_models
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
@ -32,8 +32,19 @@ class AnnotationReplyActionPayload(BaseModel):
embedding_model_name: str = Field(description="Embedding model name")
class AnnotationListQuery(BaseModel):
page: int = Field(default=1, ge=1, description="Page number")
limit: int = Field(default=20, ge=1, description="Number of annotations per page")
keyword: str = Field(default="", description="Keyword to search annotations")
register_schema_models(
service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload, Annotation, AnnotationList
service_api_ns,
AnnotationCreatePayload,
AnnotationReplyActionPayload,
AnnotationListQuery,
Annotation,
AnnotationList,
)
@ -100,6 +111,7 @@ class AnnotationReplyActionStatusApi(Resource):
class AnnotationListApi(Resource):
@service_api_ns.doc("list_annotations")
@service_api_ns.doc(description="List annotations for the application")
@service_api_ns.doc(params=query_params_from_model(AnnotationListQuery))
@service_api_ns.doc(
responses={
200: "Annotations retrieved successfully",
@ -114,18 +126,18 @@ class AnnotationListApi(Resource):
@validate_app_token
def get(self, app_model: App):
"""List annotations for the application."""
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default="", type=str)
query = AnnotationListQuery.model_validate(request.args.to_dict(flat=True))
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
app_model.id, query.page, query.limit, query.keyword
)
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
response = AnnotationList(
data=annotation_models,
has_more=len(annotation_list) == limit,
limit=limit,
has_more=len(annotation_list) == query.limit,
limit=query.limit,
total=total,
page=page,
page=query.page,
)
return response.model_dump(mode="json")

View File

@ -41,6 +41,15 @@ from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
def _resolve_agent_app_streaming(*, app_mode: AppMode, response_mode: str | None) -> bool:
"""Agent App runtime is SSE-only until backend blocking runs are supported."""
if app_mode != AppMode.AGENT:
return response_mode == "streaming"
if response_mode == "blocking":
raise BadRequest("Agent App only supports streaming response mode.")
return True
class CompletionRequestPayload(BaseModel):
inputs: dict[str, Any]
query: str = Field(default="")
@ -197,7 +206,7 @@ class ChatApi(Resource):
Supports conversation management and both blocking and streaming response modes.
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
raise NotChatAppError()
payload = ChatRequestPayload.model_validate(service_api_ns.payload or {})
@ -207,7 +216,7 @@ class ChatApi(Resource):
if external_trace_id:
args["external_trace_id"] = external_trace_id
streaming = payload.response_mode == "streaming"
streaming = _resolve_agent_app_streaming(app_mode=app_mode, response_mode=payload.response_mode)
try:
response = AppGenerateService.generate(
@ -262,7 +271,7 @@ class ChatStopApi(Resource):
def post(self, app_model: App, end_user: EndUser, task_id: str):
"""Stop a running chat message generation."""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
raise NotChatAppError()
AppTaskService.stop_task(

View File

@ -155,7 +155,7 @@ class ConversationApi(Resource):
Supports pagination using last_id and limit parameters.
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
raise NotChatAppError()
query_args = ConversationListQuery.model_validate(request.args.to_dict())
@ -199,7 +199,7 @@ class ConversationDetailApi(Resource):
def delete(self, app_model: App, end_user: EndUser, c_id: UUID):
"""Delete a specific conversation."""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -228,7 +228,7 @@ class ConversationRenameApi(Resource):
def post(self, app_model: App, end_user: EndUser, c_id: UUID):
"""Rename a conversation or auto-generate a name."""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
raise NotChatAppError()
conversation_id = str(c_id)

View File

@ -56,7 +56,7 @@ class MessageListApi(Resource):
Retrieves messages with pagination support using first_id.
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
raise NotChatAppError()
query_args = MessageListQuery.model_validate(request.args.to_dict())
@ -167,7 +167,7 @@ class MessageSuggestedApi(Resource):
"""
message_id_str = str(message_id)
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
raise NotChatAppError()
try:

View File

@ -12,7 +12,6 @@ from typing import Self
from uuid import UUID
from flask import request, send_file
from flask_restx import marshal
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import desc, func, select
from werkzeug.exceptions import Forbidden, NotFound
@ -27,7 +26,12 @@ from controllers.common.errors import (
UnsupportedFileTypeError,
)
from controllers.common.fields import UrlResponse
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
from controllers.common.schema import (
query_params_from_model,
register_enum_models,
register_response_schema_models,
register_schema_models,
)
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset.error import (
@ -44,7 +48,13 @@ from core.errors.error import ProviderTokenNotInitError
from core.rag.entities import PreProcessingRule, Rule, Segmentation
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields
from fields.base import ResponseModel
from fields.document_fields import (
DocumentListResponse,
DocumentResponse,
DocumentStatusListResponse,
)
from libs.helper import dump_response
from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import SegmentStatus
@ -107,6 +117,44 @@ class DocumentListQuery(BaseModel):
status: str | None = Field(default=None, description="Document status filter")
DOCUMENT_CREATE_BY_FILE_PARAMS = {
"dataset_id": "Dataset ID",
"file": {
"in": "formData",
"type": "file",
"required": True,
"description": "Document file to upload.",
},
"data": {
"in": "formData",
"type": "string",
"required": False,
"description": "Optional JSON string with document creation settings.",
},
}
DOCUMENT_UPDATE_BY_FILE_PARAMS = {
"dataset_id": "Dataset ID",
"document_id": "Document ID",
"file": {
"in": "formData",
"type": "file",
"required": False,
"description": "Replacement document file.",
},
"data": {
"in": "formData",
"type": "string",
"required": False,
"description": "Optional JSON string with document update settings.",
},
}
class DocumentAndBatchResponse(ResponseModel):
document: DocumentResponse
batch: str
register_enum_models(service_api_ns, RetrievalMethod)
register_schema_models(
@ -121,7 +169,14 @@ register_schema_models(
PreProcessingRule,
Segmentation,
)
register_response_schema_models(service_api_ns, UrlResponse)
register_response_schema_models(
service_api_ns,
UrlResponse,
DocumentResponse,
DocumentAndBatchResponse,
DocumentListResponse,
DocumentStatusListResponse,
)
def _create_document_by_text(tenant_id: str, dataset_id: UUID) -> tuple[Mapping[str, object], int]:
@ -188,8 +243,7 @@ def _create_document_by_text(tenant_id: str, dataset_id: UUID) -> tuple[Mapping[
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
return documents_and_batch_fields, 200
return dump_response(DocumentAndBatchResponse, {"document": document, "batch": batch}), 200
def _update_document_by_text(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
@ -248,8 +302,7 @@ def _update_document_by_text(tenant_id: str, dataset_id: UUID, document_id: UUID
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
return documents_and_batch_fields, 200
return dump_response(DocumentAndBatchResponse, {"document": document, "batch": batch}), 200
@service_api_ns.route("/datasets/<uuid:dataset_id>/document/create-by-text")
@ -267,6 +320,9 @@ class DocumentAddByTextApi(DatasetApiResource):
400: "Bad request - invalid parameters",
}
)
@service_api_ns.response(
200, "Document created successfully", service_api_ns.models[DocumentAndBatchResponse.__name__]
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_resource_check("documents", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
@ -296,6 +352,9 @@ class DeprecatedDocumentAddByTextApi(DatasetApiResource):
400: "Bad request - invalid parameters",
}
)
@service_api_ns.response(
200, "Document created successfully", service_api_ns.models[DocumentAndBatchResponse.__name__]
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_resource_check("documents", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
@ -319,6 +378,9 @@ class DocumentUpdateByTextApi(DatasetApiResource):
404: "Document not found",
}
)
@service_api_ns.response(
200, "Document updated successfully", service_api_ns.models[DocumentAndBatchResponse.__name__]
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
@ -347,6 +409,9 @@ class DeprecatedDocumentUpdateByTextApi(DatasetApiResource):
404: "Document not found",
}
)
@service_api_ns.response(
200, "Document updated successfully", service_api_ns.models[DocumentAndBatchResponse.__name__]
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
@ -363,7 +428,7 @@ class DocumentAddByFileApi(DatasetApiResource):
@service_api_ns.doc("create_document_by_file")
@service_api_ns.doc(description="Create a new document by uploading a file")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@service_api_ns.doc(consumes=["multipart/form-data"], params=DOCUMENT_CREATE_BY_FILE_PARAMS)
@service_api_ns.doc(
responses={
200: "Document created successfully",
@ -371,6 +436,9 @@ class DocumentAddByFileApi(DatasetApiResource):
400: "Bad request - invalid file or parameters",
}
)
@service_api_ns.response(
200, "Document created successfully", service_api_ns.models[DocumentAndBatchResponse.__name__]
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_resource_check("documents", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
@ -462,8 +530,7 @@ class DocumentAddByFileApi(DatasetApiResource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
return documents_and_batch_fields, 200
return dump_response(DocumentAndBatchResponse, {"document": document, "batch": batch}), 200
def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
@ -539,8 +606,7 @@ def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
return documents_and_batch_fields, 200
return dump_response(DocumentAndBatchResponse, {"document": document, "batch": document.batch}), 200
@service_api_ns.route(
@ -558,7 +624,7 @@ class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
"Use PATCH /datasets/{dataset_id}/documents/{document_id} instead."
)
)
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@service_api_ns.doc(consumes=["multipart/form-data"], params=DOCUMENT_UPDATE_BY_FILE_PARAMS)
@service_api_ns.doc(
responses={
200: "Document updated successfully",
@ -566,6 +632,9 @@ class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
404: "Document not found",
}
)
@service_api_ns.response(
200, "Document updated successfully", service_api_ns.models[DocumentAndBatchResponse.__name__]
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
@ -577,7 +646,7 @@ class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
class DocumentListApi(DatasetApiResource):
@service_api_ns.doc("list_documents")
@service_api_ns.doc(description="List all documents in a dataset")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@service_api_ns.doc(params={"dataset_id": "Dataset ID", **query_params_from_model(DocumentListQuery)})
@service_api_ns.doc(
responses={
200: "Documents retrieved successfully",
@ -585,6 +654,9 @@ class DocumentListApi(DatasetApiResource):
404: "Dataset not found",
}
)
@service_api_ns.response(
200, "Documents retrieved successfully", service_api_ns.models[DocumentListResponse.__name__]
)
def get(self, tenant_id, dataset_id: UUID):
dataset_id_str = str(dataset_id)
tenant_id = str(tenant_id)
@ -618,14 +690,14 @@ class DocumentListApi(DatasetApiResource):
)
response = {
"data": marshal(documents, document_fields),
"data": documents,
"has_more": len(documents) == query_params.limit,
"limit": query_params.limit,
"total": paginated_documents.total,
"page": query_params.page,
}
return response
return dump_response(DocumentListResponse, response)
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/download-zip")
@ -680,6 +752,11 @@ class DocumentIndexingStatusApi(DatasetApiResource):
404: "Dataset or documents not found",
}
)
@service_api_ns.response(
200,
"Indexing status retrieved successfully",
service_api_ns.models[DocumentStatusListResponse.__name__],
)
def get(self, tenant_id, dataset_id: UUID, batch: str):
dataset_id_str = str(dataset_id)
tenant_id = str(tenant_id)
@ -729,9 +806,8 @@ class DocumentIndexingStatusApi(DatasetApiResource):
"completed_segments": completed_segments,
"total_segments": total_segments,
}
documents_status.append(marshal(document_dict, document_status_fields))
data = {"data": documents_status}
return data
documents_status.append(document_dict)
return dump_response(DocumentStatusListResponse, {"data": documents_status})
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/download")
@ -890,7 +966,7 @@ class DocumentApi(DatasetApiResource):
@service_api_ns.doc("update_document_by_file")
@service_api_ns.doc(description="Update an existing document by uploading a file")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@service_api_ns.doc(consumes=["multipart/form-data"], params=DOCUMENT_UPDATE_BY_FILE_PARAMS)
@service_api_ns.doc(
responses={
200: "Document updated successfully",
@ -898,6 +974,9 @@ class DocumentApi(DatasetApiResource):
404: "Document not found",
}
)
@service_api_ns.response(
200, "Document updated successfully", service_api_ns.models[DocumentAndBatchResponse.__name__]
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID):

View File

@ -1,15 +1,18 @@
from typing import Any
from typing import cast
from uuid import UUID
from flask import request
from flask_restx import marshal
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ValidationError, field_validator
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from configs import dify_config
from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
from controllers.common.schema import register_schema_models
from controllers.common.schema import (
query_params_from_model,
query_params_from_request,
register_response_schema_models,
register_schema_models,
)
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import (
@ -22,10 +25,19 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_database import db
from fields.segment_fields import child_chunk_fields, segment_fields
from fields.base import ResponseModel
from fields.segment_fields import (
ChildChunkDetailResponse,
ChildChunkListResponse,
SegmentDetailResponse,
SegmentResponse,
segment_response_with_summary,
segment_responses_with_summaries,
)
from graphon.model_runtime.entities.model_entities import ModelType
from libs.helper import dump_response
from libs.login import current_account_with_tenant
from models.dataset import Dataset
from models.dataset import Dataset, DocumentSegment
from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
@ -34,35 +46,27 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from services.summary_index_service import SummaryIndexService
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict[str, Any]:
"""Marshal a single segment and enrich it with summary content."""
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict
class SegmentCreateItemPayload(BaseModel):
content: str = Field(min_length=1)
answer: str | None = None
keywords: list[str] | None = None
attachment_ids: list[str] | None = None
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict[str, Any]]:
"""Marshal multiple segments and enrich them with summary content (batch query)."""
segment_ids = [segment.id for segment in segments]
summaries: dict[str, str | None] = {}
if segment_ids:
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
result: list[dict[str, Any]] = []
for segment in segments:
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict["summary"] = summaries.get(segment.id)
result.append(segment_dict)
return result
@field_validator("content")
@classmethod
def validate_content(cls, value: str) -> str:
if not value.strip():
raise ValueError("Content is empty")
return value
class SegmentCreatePayload(BaseModel):
segments: list[dict[str, Any]] | None = None
segments: list[SegmentCreateItemPayload] = Field(min_length=1)
class SegmentListQuery(BaseModel):
limit: int = Field(default=20, ge=1)
page: int = Field(default=1, ge=1)
status: list[str] = Field(default_factory=list)
keyword: str | None = None
@ -77,9 +81,31 @@ class ChildChunkListQuery(BaseModel):
page: int = Field(default=1, ge=1)
class SegmentDocParams:
DATASET_DOCUMENT = {"dataset_id": "Dataset ID", "document_id": "Document ID"}
DATASET_DOCUMENT_SEGMENT = {**DATASET_DOCUMENT, "segment_id": "Segment ID"}
DATASET_DOCUMENT_PARENT_SEGMENT = {**DATASET_DOCUMENT, "segment_id": "Parent segment ID"}
DATASET_DOCUMENT_CHILD_CHUNK = {**DATASET_DOCUMENT_PARENT_SEGMENT, "child_chunk_id": "Child chunk ID"}
class SegmentCreateListResponse(ResponseModel):
data: list[SegmentResponse]
doc_form: str
class SegmentListResponse(ResponseModel):
data: list[SegmentResponse]
doc_form: str
total: int
has_more: bool
limit: int
page: int
register_schema_models(
service_api_ns,
SegmentCreatePayload,
SegmentCreateItemPayload,
SegmentListQuery,
SegmentUpdateArgs,
SegmentUpdatePayload,
@ -87,6 +113,15 @@ register_schema_models(
ChildChunkListQuery,
ChildChunkUpdatePayload,
)
register_response_schema_models(
service_api_ns,
SegmentResponse,
SegmentCreateListResponse,
SegmentListResponse,
SegmentDetailResponse,
ChildChunkDetailResponse,
ChildChunkListResponse,
)
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
@ -96,7 +131,7 @@ class SegmentApi(DatasetApiResource):
@service_api_ns.expect(service_api_ns.models[SegmentCreatePayload.__name__])
@service_api_ns.doc("create_segments")
@service_api_ns.doc(description="Create segments in a document")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
@service_api_ns.doc(
responses={
200: "Segments created successfully",
@ -105,6 +140,11 @@ class SegmentApi(DatasetApiResource):
404: "Dataset or document not found",
}
)
@service_api_ns.response(
200,
"Segments created successfully",
service_api_ns.models[SegmentCreateListResponse.__name__],
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
@ -144,26 +184,35 @@ class SegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# validate args
payload = SegmentCreatePayload.model_validate(service_api_ns.payload or {})
if payload.segments is not None:
segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
if segments_limit > 0 and len(payload.segments) > segments_limit:
raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")
try:
payload = SegmentCreatePayload.model_validate(service_api_ns.payload or {})
except ValidationError as e:
return {"error": str(e)}, 400
segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
if segments_limit > 0 and len(payload.segments) > segments_limit:
raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")
segment_items = [segment.model_dump(exclude_none=True) for segment in payload.segments]
for args_item in payload.segments:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
return {
"data": _marshal_segments_with_summary(segments, dataset_id_str),
"doc_form": document.doc_form,
}, 200
else:
return {"error": "Segments is required"}, 400
for args_item in segment_items:
SegmentService.segment_create_args_validate(args_item, document)
segments = cast(list[DocumentSegment], SegmentService.multi_create_segment(segment_items, document, dataset))
segment_ids = [segment.id for segment in segments]
summaries: dict[str, str | None] = {}
if segment_ids:
summary_records = SummaryIndexService.get_segments_summaries(
segment_ids=segment_ids, dataset_id=dataset_id_str
)
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
response = {
"data": segment_responses_with_summaries(segments, summaries),
"doc_form": document.doc_form,
}
return dump_response(SegmentCreateListResponse, response), 200
@service_api_ns.expect(service_api_ns.models[SegmentListQuery.__name__])
@service_api_ns.doc("list_segments")
@service_api_ns.doc(description="List segments in a document")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
@service_api_ns.doc(params=query_params_from_model(SegmentListQuery))
@service_api_ns.doc(
responses={
200: "Segments retrieved successfully",
@ -171,12 +220,22 @@ class SegmentApi(DatasetApiResource):
404: "Dataset or document not found",
}
)
@service_api_ns.response(
200,
"Segments retrieved successfully",
service_api_ns.models[SegmentListResponse.__name__],
)
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
_, current_tenant_id = current_account_with_tenant()
"""Get segments."""
# check dataset
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
args = query_params_from_request(
SegmentListQuery,
list_fields=("status",),
use_defaults_for_malformed_ints=True,
)
page = args.page
limit = args.limit
dataset_id_str = str(dataset_id)
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
@ -205,13 +264,6 @@ class SegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
args = SegmentListQuery.model_validate(
{
"status": request.args.getlist("status"),
"keyword": request.args.get("keyword"),
}
)
segments, total = SegmentService.get_segments(
document_id=document_id_str,
tenant_id=current_tenant_id,
@ -220,9 +272,16 @@ class SegmentApi(DatasetApiResource):
page=page,
limit=limit,
)
segment_ids = [segment.id for segment in segments]
summaries: dict[str, str | None] = {}
if segment_ids:
summary_records = SummaryIndexService.get_segments_summaries(
segment_ids=segment_ids, dataset_id=dataset_id_str
)
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
response = {
"data": _marshal_segments_with_summary(segments, dataset_id_str),
"data": segment_responses_with_summaries(segments, summaries),
"doc_form": document.doc_form,
"total": total,
"has_more": len(segments) == limit,
@ -230,16 +289,14 @@ class SegmentApi(DatasetApiResource):
"page": page,
}
return response, 200
return dump_response(SegmentListResponse, response), 200
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
class DatasetSegmentApi(DatasetApiResource):
@service_api_ns.doc("delete_segment")
@service_api_ns.doc(description="Delete a specific segment")
@service_api_ns.doc(
params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Segment ID to delete"}
)
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
@service_api_ns.doc(
responses={
204: "Segment deleted successfully",
@ -275,9 +332,7 @@ class DatasetSegmentApi(DatasetApiResource):
@service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__])
@service_api_ns.doc("update_segment")
@service_api_ns.doc(description="Update a specific segment")
@service_api_ns.doc(
params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Segment ID to update"}
)
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
@service_api_ns.doc(
responses={
200: "Segment updated successfully",
@ -285,6 +340,7 @@ class DatasetSegmentApi(DatasetApiResource):
404: "Dataset, document, or segment not found",
}
)
@service_api_ns.response(200, "Segment updated successfully", service_api_ns.models[SegmentDetailResponse.__name__])
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
@ -328,13 +384,16 @@ class DatasetSegmentApi(DatasetApiResource):
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
return {
"data": _marshal_segment_with_summary(updated_segment, dataset_id_str),
summary = SummaryIndexService.get_segment_summary(segment_id=updated_segment.id, dataset_id=dataset_id_str)
response = {
"data": segment_response_with_summary(updated_segment, summary.summary_content if summary else None),
"doc_form": document.doc_form,
}, 200
}
return dump_response(SegmentDetailResponse, response), 200
@service_api_ns.doc("get_segment")
@service_api_ns.doc(description="Get a specific segment by ID")
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
@service_api_ns.doc(
responses={
200: "Segment retrieved successfully",
@ -342,6 +401,11 @@ class DatasetSegmentApi(DatasetApiResource):
404: "Dataset, document, or segment not found",
}
)
@service_api_ns.response(
200,
"Segment retrieved successfully",
service_api_ns.models[SegmentDetailResponse.__name__],
)
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
@ -364,7 +428,12 @@ class DatasetSegmentApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
return {"data": _marshal_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id_str)
response = {
"data": segment_response_with_summary(segment, summary.summary_content if summary else None),
"doc_form": document.doc_form,
}
return dump_response(SegmentDetailResponse, response), 200
@service_api_ns.route(
@ -376,9 +445,7 @@ class ChildChunkApi(DatasetApiResource):
@service_api_ns.expect(service_api_ns.models[ChildChunkCreatePayload.__name__])
@service_api_ns.doc("create_child_chunk")
@service_api_ns.doc(description="Create a new child chunk for a segment")
@service_api_ns.doc(
params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Parent segment ID"}
)
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
@service_api_ns.doc(
responses={
200: "Child chunk created successfully",
@ -386,6 +453,11 @@ class ChildChunkApi(DatasetApiResource):
404: "Dataset, document, or segment not found",
}
)
@service_api_ns.response(
200,
"Child chunk created successfully",
service_api_ns.models[ChildChunkDetailResponse.__name__],
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
@ -437,14 +509,12 @@ class ChildChunkApi(DatasetApiResource):
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
return dump_response(ChildChunkDetailResponse, {"data": child_chunk}), 200
@service_api_ns.expect(service_api_ns.models[ChildChunkListQuery.__name__])
@service_api_ns.doc("list_child_chunks")
@service_api_ns.doc(description="List child chunks for a segment")
@service_api_ns.doc(
params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Parent segment ID"}
)
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
@service_api_ns.doc(params=query_params_from_model(ChildChunkListQuery))
@service_api_ns.doc(
responses={
200: "Child chunks retrieved successfully",
@ -452,6 +522,11 @@ class ChildChunkApi(DatasetApiResource):
404: "Dataset, document, or segment not found",
}
)
@service_api_ns.response(
200,
"Child chunks retrieved successfully",
service_api_ns.models[ChildChunkListResponse.__name__],
)
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
_, current_tenant_id = current_account_with_tenant()
"""Get child chunks."""
@ -475,13 +550,7 @@ class ChildChunkApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
args = ChildChunkListQuery.model_validate(
{
"limit": request.args.get("limit", default=20, type=int),
"keyword": request.args.get("keyword"),
"page": request.args.get("page", default=1, type=int),
}
)
args = query_params_from_request(ChildChunkListQuery, use_defaults_for_malformed_ints=True)
page = args.page
limit = min(args.limit, 100)
@ -491,13 +560,14 @@ class ChildChunkApi(DatasetApiResource):
segment_id_str, document_id_str, dataset_id_str, page, limit, keyword
)
return {
"data": marshal(child_chunks.items, child_chunk_fields),
response = {
"data": child_chunks.items,
"total": child_chunks.total,
"total_pages": child_chunks.pages,
"page": page,
"limit": limit,
}, 200
}
return dump_response(ChildChunkListResponse, response), 200
@service_api_ns.route(
@ -508,14 +578,7 @@ class DatasetChildChunkApi(DatasetApiResource):
@service_api_ns.doc("delete_child_chunk")
@service_api_ns.doc(description="Delete a specific child chunk")
@service_api_ns.doc(
params={
"dataset_id": "Dataset ID",
"document_id": "Document ID",
"segment_id": "Parent segment ID",
"child_chunk_id": "Child chunk ID to delete",
}
)
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
@service_api_ns.doc(
responses={
204: "Child chunk deleted successfully",
@ -549,7 +612,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Segment not found.")
# validate segment belongs to the specified document
if str(segment.document_id) != str(document_id_str):
if segment.document_id != document_id_str:
raise NotFound("Document not found.")
child_chunk_id_str = str(child_chunk_id)
@ -561,7 +624,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Child chunk not found.")
# validate child chunk belongs to the specified segment
if str(child_chunk.segment_id) != str(segment.id):
if child_chunk.segment_id != segment.id:
raise NotFound("Child chunk not found.")
try:
@ -574,14 +637,7 @@ class DatasetChildChunkApi(DatasetApiResource):
@service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__])
@service_api_ns.doc("update_child_chunk")
@service_api_ns.doc(description="Update a specific child chunk")
@service_api_ns.doc(
params={
"dataset_id": "Dataset ID",
"document_id": "Document ID",
"segment_id": "Parent segment ID",
"child_chunk_id": "Child chunk ID to update",
}
)
@service_api_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
@service_api_ns.doc(
responses={
200: "Child chunk updated successfully",
@ -589,6 +645,11 @@ class DatasetChildChunkApi(DatasetApiResource):
404: "Dataset, document, segment, or child chunk not found",
}
)
@service_api_ns.response(
200,
"Child chunk updated successfully",
service_api_ns.models[ChildChunkDetailResponse.__name__],
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
@ -616,7 +677,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Segment not found.")
# validate segment belongs to the specified document
if str(segment.document_id) != str(document_id_str):
if segment.document_id != document_id_str:
raise NotFound("Segment not found.")
child_chunk_id_str = str(child_chunk_id)
@ -628,7 +689,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Child chunk not found.")
# validate child chunk belongs to the specified segment
if str(child_chunk.segment_id) != str(segment.id):
if child_chunk.segment_id != segment.id:
raise NotFound("Child chunk not found.")
# validate args
@ -639,4 +700,4 @@ class DatasetChildChunkApi(DatasetApiResource):
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
return dump_response(ChildChunkDetailResponse, {"data": child_chunk}), 200

View File

@ -2,7 +2,7 @@ import logging
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
from controllers.common.fields import SimpleResultResponse
@ -37,6 +37,15 @@ from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
def _resolve_agent_app_streaming(*, app_mode: AppMode, response_mode: str | None) -> bool:
"""Agent App runtime is SSE-only until backend blocking runs are supported."""
if app_mode != AppMode.AGENT:
return response_mode == "streaming"
if response_mode == "blocking":
raise BadRequest("Agent App only supports streaming response mode.")
return True
class CompletionMessagePayload(BaseModel):
inputs: dict[str, Any] = Field(description="Input variables for the completion")
query: str = Field(default="", description="Query text for completion")
@ -171,13 +180,13 @@ class ChatApi(WebApiResource):
)
def post(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
raise NotChatAppError()
payload = ChatMessagePayload.model_validate(web_ns.payload or {})
args = payload.model_dump(exclude_none=True)
streaming = payload.response_mode == "streaming"
streaming = _resolve_agent_app_streaming(app_mode=app_mode, response_mode=payload.response_mode)
args["auto_generate_name"] = False
try:
@ -228,7 +237,7 @@ class ChatStopApi(WebApiResource):
@web_ns.response(200, "Success", web_ns.models[SimpleResultResponse.__name__])
def post(self, app_model: App, end_user: EndUser, task_id: str):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
raise NotChatAppError()
AppTaskService.stop_task(

View File

@ -83,7 +83,7 @@ class ConversationListApi(WebApiResource):
)
def get(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
raise NotChatAppError()
raw_args = request.args.to_dict()
@ -129,7 +129,7 @@ class ConversationApi(WebApiResource):
)
def delete(self, app_model: App, end_user: EndUser, c_id: UUID):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -168,7 +168,7 @@ class ConversationRenameApi(WebApiResource):
)
def post(self, app_model: App, end_user: EndUser, c_id: UUID):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -206,7 +206,7 @@ class ConversationPinApi(WebApiResource):
@web_ns.response(200, "Conversation pinned successfully", web_ns.models[ResultResponse.__name__])
def patch(self, app_model: App, end_user: EndUser, c_id: UUID):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -237,7 +237,7 @@ class ConversationUnPinApi(WebApiResource):
@web_ns.response(200, "Conversation unpinned successfully", web_ns.models[ResultResponse.__name__])
def patch(self, app_model: App, end_user: EndUser, c_id: UUID):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
raise NotChatAppError()
conversation_id = str(c_id)

View File

@ -83,7 +83,7 @@ class MessageListApi(WebApiResource):
)
def get(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
raise NotChatAppError()
raw_args = request.args.to_dict()
@ -225,7 +225,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
)
def get(self, app_model: App, end_user: EndUser, message_id: UUID):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
raise NotChatAppError()
message_id_str = str(message_id)

View File

@ -9,7 +9,7 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
from core.helper import ssrf_proxy
from core.file import remote_fetcher
from extensions.ext_database import db
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from graphon.file import helpers as file_helpers
@ -60,10 +60,10 @@ class RemoteFileInfoApi(WebApiResource):
HTTPException: If the remote file cannot be accessed
"""
decoded_url = helpers.decode_remote_url(url, request.query_string)
resp = ssrf_proxy.head(decoded_url)
resp = remote_fetcher.make_request("HEAD", decoded_url)
if resp.status_code != httpx.codes.OK:
# failed back to get method
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp = remote_fetcher.make_request("GET", decoded_url, timeout=3)
resp.raise_for_status()
info = RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
@ -112,9 +112,9 @@ class RemoteFileUploadApi(WebApiResource):
url = str(payload.url)
try:
resp = ssrf_proxy.head(url=url)
resp = remote_fetcher.make_request("HEAD", url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
resp = remote_fetcher.make_request("GET", url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as e:
@ -125,7 +125,7 @@ class RemoteFileUploadApi(WebApiResource):
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
raise FileTooLargeError
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
content = resp.content if resp.request.method == "GET" else remote_fetcher.make_request("GET", url).content
try:
upload_file = FileService(db.engine).upload_file(

View File

@ -397,39 +397,40 @@ class CotAgentRunner(BaseAgentRunner, ABC):
current_scratchpad: AgentScratchpadUnit | None = None
for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage):
if not current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content or "I am thinking about how to help you",
action_str="",
action=None,
observation=None,
)
scratchpads.append(current_scratchpad)
if message.tool_calls:
try:
current_scratchpad.action = AgentScratchpadUnit.Action(
action_name=message.tool_calls[0].function.name,
action_input=json.loads(message.tool_calls[0].function.arguments),
match message:
case AssistantPromptMessage():
if not current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content or "I am thinking about how to help you",
action_str="",
action=None,
observation=None,
)
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
except Exception:
logger.exception("Failed to parse tool call from assistant message")
elif isinstance(message, ToolPromptMessage):
if current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad.observation = message.content
else:
raise NotImplementedError("expected str type")
elif isinstance(message, UserPromptMessage):
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
scratchpads = []
current_scratchpad = None
scratchpads.append(current_scratchpad)
if message.tool_calls:
try:
current_scratchpad.action = AgentScratchpadUnit.Action(
action_name=message.tool_calls[0].function.name,
action_input=json.loads(message.tool_calls[0].function.arguments),
)
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
except Exception:
logger.exception("Failed to parse tool call from assistant message")
case ToolPromptMessage():
if current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad.observation = message.content
else:
raise NotImplementedError("expected str type")
case UserPromptMessage():
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
scratchpads = []
current_scratchpad = None
result.append(message)
result.append(message)
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))

View File

@ -39,16 +39,17 @@ class CotCompletionAgentRunner(CotAgentRunner):
historic_prompt = ""
for message in historic_prompt_messages:
if isinstance(message, UserPromptMessage):
historic_prompt += f"Question: {message.content}\n\n"
elif isinstance(message, AssistantPromptMessage):
if isinstance(message.content, str):
historic_prompt += message.content + "\n\n"
elif isinstance(message.content, list):
for content in message.content:
if not isinstance(content, TextPromptMessageContent):
continue
historic_prompt += content.data
match message:
case UserPromptMessage():
historic_prompt += f"Question: {message.content}\n\n"
case AssistantPromptMessage():
if isinstance(message.content, str):
historic_prompt += message.content + "\n\n"
elif isinstance(message.content, list):
for content in message.content:
if not isinstance(content, TextPromptMessageContent):
continue
historic_prompt += content.data
return historic_prompt

View File

@ -237,7 +237,9 @@ class EasyUIBasedAppConfig(AppConfig):
"""
app_model_config_from: EasyUIBasedAppModelConfigFrom
app_model_config_id: str
# Optional: an Agent App has no legacy app_model_config row, so the id may be
# absent (persistence then stores NULL for the conversation's id).
app_model_config_id: str | None = None
app_model_config_dict: dict[str, Any]
model: ModelConfigEntity
prompt_template: PromptTemplateEntity

View File

@ -1,4 +1,6 @@
import json
import re
from typing import Any
from core.app.app_config.entities import RagPipelineVariableEntity
from graphon.variables.input_entities import VariableEntity
@ -20,10 +22,32 @@ class WorkflowVariablesConfigManager:
# variables
for variable in user_input_form:
cls._normalize_json_schema(variable)
variables.append(VariableEntity.model_validate(variable))
return variables
@staticmethod
def _normalize_json_schema(variable: dict[str, Any]) -> None:
"""
Normalize ``json_schema`` from a JSON string to a dict.
The workflow graph is stored as JSON in the database. When a JSON
object variable carries a ``json_schema`` field, nested dicts are
preserved correctly, but older data or certain serialization paths
may store it as a JSON *string* instead of a native dict.
``VariableEntity.json_schema`` expects ``dict | None``, so we
deserialize the string here before handing it to Pydantic.
"""
json_schema = variable.get("json_schema")
if isinstance(json_schema, str):
try:
variable["json_schema"] = json.loads(json_schema)
except (json.JSONDecodeError, TypeError):
# Leave as-is; Pydantic validation will surface the error.
pass
@classmethod
def convert_rag_pipeline_variable(cls, workflow: Workflow, start_node_id: str) -> list[RagPipelineVariableEntity]:
"""

View File

@ -134,17 +134,18 @@ class AdvancedChatAppGenerateResponseConverter(
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
elif isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
match sub_stream_response:
case MessageEndStreamResponse():
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
case ErrorStreamResponse():
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
case NodeStartStreamResponse() | NodeFinishStreamResponse():
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
case _:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -161,16 +161,17 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
stream=stream,
)
if isinstance(user, EndUser):
self._user_id = user.id
user_session_id = user.session_id
self._created_by_role = CreatorUserRole.END_USER
elif isinstance(user, Account):
self._user_id = user.id
user_session_id = user.id
self._created_by_role = CreatorUserRole.ACCOUNT
else:
raise NotImplementedError(f"User type not supported: {type(user)}")
match user:
case EndUser():
self._user_id = user.id
user_session_id = user.session_id
self._created_by_role = CreatorUserRole.END_USER
case Account():
self._user_id = user.id
user_session_id = user.id
self._created_by_role = CreatorUserRole.ACCOUNT
case _:
raise NotImplementedError(f"User type not supported: {type(user)}")
self._workflow_system_variables = build_system_variables(
query=message.query,

View File

View File

@ -0,0 +1,106 @@
"""Build the EasyUI-style app config for an Agent App from its Agent Soul.
An Agent App has no legacy ``app_model_config``: its model / prompt live in the
bound Agent Soul snapshot. To ride the existing chat message + SSE pipeline we
synthesize an ``app_model_config``-shaped dict from the Soul (model + system
prompt) plus any app-level feature flags (opening statement, follow-up, …)
stored on ``app_model_config`` when present, then reuse the same sub-managers
the chat app type uses.
"""
from typing import Any, cast
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager
from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager
from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager
from core.app.app_config.entities import (
EasyUIBasedAppConfig,
EasyUIBasedAppModelConfigFrom,
PromptTemplateEntity,
)
from models.agent_config_entities import AgentSoulConfig
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation
class AgentAppConfig(EasyUIBasedAppConfig):
"""Agent App config entity (EasyUI-shaped so it rides the chat pipeline).
``app_model_config_id`` is inherited as ``str | None``: an Agent App may have
no legacy ``app_model_config`` row, in which case persistence stores ``NULL``
for the conversation's ``app_model_config_id``.
"""
class AgentAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(
cls,
*,
app_model: App,
agent_soul: AgentSoulConfig,
app_model_config: AppModelConfig | None = None,
conversation: Conversation | None = None,
) -> AgentAppConfig:
"""Build the Agent App config from the Agent Soul (+ optional feature flags)."""
config_dict = cls._synthesize_config_dict(agent_soul, app_model_config)
# The synthesized dict is shaped like an app_model_config; the EasyUI
# sub-managers type their param as AppModelConfigDict (a TypedDict).
typed_config = cast(AppModelConfigDict, config_dict)
app_mode = AppMode.value_of(app_model.mode)
app_config = AgentAppConfig(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
app_mode=app_mode,
# The config is derived from the Agent Soul snapshot, not a legacy
# app_model_config row; the id is informational only.
app_model_config_from=EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG,
app_model_config_id=app_model_config.id if app_model_config else None,
app_model_config_dict=config_dict,
model=ModelConfigManager.convert(config=typed_config),
prompt_template=PromptTemplateConfigManager.convert(config=typed_config),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
dataset=DatasetConfigManager.convert(config=typed_config),
additional_features=cls.convert_features(config_dict, app_mode),
)
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
config=typed_config
)
return app_config
@staticmethod
def _synthesize_config_dict(
agent_soul: AgentSoulConfig,
app_model_config: AppModelConfig | None,
) -> dict[str, Any]:
"""Shape a Soul + feature flags into an ``app_model_config``-style dict.
Feature flags (opening statement / follow-up / tts / stt / citations /
moderation / annotation) come from ``app_model_config`` when present
(Q3: stored there), otherwise defaults; model + prompt always come from
the Agent Soul (the single source of truth for those).
"""
base: dict[str, Any] = dict(app_model_config.to_dict()) if app_model_config else {}
model = agent_soul.model
if model is not None:
base["model"] = {
"provider": model.model_provider,
"name": model.model,
"mode": "chat",
"completion_params": model.model_settings.model_dump(mode="json", exclude_none=True),
}
# The Agent Soul system prompt rides the EasyUI "simple" prompt slot; the
# agent backend is the real prompt authority, this only feeds the chat
# pipeline's bookkeeping (token counting, persistence).
base["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
base["pre_prompt"] = agent_soul.prompt.system_prompt or ""
# Agent App takes the user message directly; no completion-style inputs form.
base.setdefault("user_input_form", [])
return base
__all__ = ["AgentAppConfig", "AgentAppConfigManager"]

View File

@ -0,0 +1,327 @@
"""Agent App generator: orchestrate one conversation turn for an Agent App.
Mirrors the agent_chat generator (conversation + message + queue + streamed
response over the EasyUI chat pipeline), but the backing config comes from the
bound Agent Soul and the answer is produced by ``AgentAppRunner`` calling the
dify-agent backend rather than an in-process LLM/ReAct loop.
"""
from __future__ import annotations
import contextvars
import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any
from flask import Flask, current_app
from sqlalchemy import select
from clients.agent_backend import AgentBackendRunEventAdapter
from clients.agent_backend.factory import create_agent_backend_run_client
from configs import dify_config
from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.apps.agent_app.app_config_manager import AgentAppConfigManager
from core.app.apps.agent_app.app_runner import AgentAppRunner
from core.app.apps.agent_app.generate_response_converter import AgentAppGenerateResponseConverter
from core.app.apps.agent_app.runtime_request_builder import AgentAppRuntimeRequestBuilder
from core.app.apps.agent_app.session_store import AgentAppRuntimeSessionStore
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import (
AgentAppGenerateEntity,
DifyRunContext,
InvokeFrom,
UserFrom,
)
from core.app.llm.model_access import build_dify_model_access
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from models import Account, App, EndUser, Message
from models.agent import Agent, AgentConfigSnapshot, AgentScope, AgentSource, AgentStatus
from models.agent_config_entities import AgentSoulConfig
from services.conversation_service import ConversationService
logger = logging.getLogger(__name__)
class AgentAppGeneratorError(ValueError):
"""Raised when an Agent App turn cannot be set up."""
class AgentAppGenerator(MessageBasedAppGenerator):
def generate(
self,
*,
app_model: App,
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[Mapping | str, None, None]:
if not streaming:
raise AgentAppGeneratorError("Agent App only supports streaming mode")
query = args.get("query")
if not isinstance(query, str) or not query.strip():
raise AgentAppGeneratorError("query is required")
query = query.replace("\x00", "")
inputs = args["inputs"]
# Resolve the bound roster Agent + its published Agent Soul snapshot.
agent, snapshot, agent_soul = self._resolve_agent(app_model)
conversation = None
conversation_id = args.get("conversation_id")
if conversation_id:
conversation = ConversationService.get_conversation(
app_model=app_model, conversation_id=conversation_id, user=user
)
# Build the EasyUI-shaped config from the Agent Soul so the chat pipeline
# can persist usage; the answer itself comes from the agent backend.
app_model_config = app_model.app_model_config
app_config = AgentAppConfigManager.get_app_config(
app_model=app_model,
agent_soul=agent_soul,
app_model_config=app_model_config,
conversation=conversation,
)
model_conf = ModelConfigConverter.convert(app_config)
trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id)
application_generate_entity = AgentAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=model_conf,
conversation_id=conversation.id if conversation else None,
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=[],
parent_message_id=(
args.get("parent_message_id")
if invoke_from not in {InvokeFrom.SERVICE_API, InvokeFrom.OPENAPI}
else UUID_NIL
),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
extras={"auto_generate_conversation_name": args.get("auto_generate_name", True)},
call_depth=0,
trace_manager=trace_manager,
agent_id=agent.id,
agent_config_snapshot_id=snapshot.id,
)
conversation, message = self._init_generate_records(application_generate_entity, conversation)
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
)
context = contextvars.copy_context()
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"context": context,
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
"user_from": UserFrom.ACCOUNT if isinstance(user, Account) else UserFrom.END_USER,
},
)
worker_thread.start()
response = self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
stream=streaming,
)
return AgentAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(
self,
*,
flask_app: Flask,
context: contextvars.Context,
application_generate_entity: AgentAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
user_from: UserFrom,
) -> None:
from libs.flask_utils import preserve_flask_contexts
with preserve_flask_contexts(flask_app, context_vars=context):
try:
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
app_config = application_generate_entity.app_config
# Apply app-level input guards (content moderation + annotation
# reply) before reaching the Agent backend, mirroring the EasyUI
# chat / agent-chat runners. These can short-circuit the turn.
app_model = db.session.get(App, app_config.app_id)
if app_model is None:
raise AgentAppGeneratorError("App not found")
handled, query = self._run_input_guards(
application_generate_entity=application_generate_entity,
app_model=app_model,
message=message,
queue_manager=queue_manager,
)
if handled:
return
dify_context = DifyRunContext(
tenant_id=app_config.tenant_id,
app_id=app_config.app_id,
user_id=application_generate_entity.user_id,
user_from=user_from,
invoke_from=application_generate_entity.invoke_from,
)
credentials_provider, _ = build_dify_model_access(dify_context)
_, _, agent_soul = self._resolve_agent_by_id(
tenant_id=app_config.tenant_id,
agent_id=application_generate_entity.agent_id,
snapshot_id=application_generate_entity.agent_config_snapshot_id,
)
runner = AgentAppRunner(
request_builder=AgentAppRuntimeRequestBuilder(credentials_provider=credentials_provider),
agent_backend_client=create_agent_backend_run_client(
base_url=dify_config.AGENT_BACKEND_BASE_URL,
use_fake=dify_config.AGENT_BACKEND_USE_FAKE,
fake_scenario=dify_config.AGENT_BACKEND_FAKE_SCENARIO,
),
event_adapter=AgentBackendRunEventAdapter(),
session_store=AgentAppRuntimeSessionStore(),
)
runner.run(
dify_context=dify_context,
agent_id=application_generate_entity.agent_id,
agent_config_snapshot_id=application_generate_entity.agent_config_snapshot_id,
agent_soul=agent_soul,
conversation_id=conversation.id,
query=query,
message_id=message.id,
model_name=application_generate_entity.model_conf.model,
queue_manager=queue_manager,
)
except GenerateTaskStoppedError:
pass
except Exception as e:
logger.exception("Unknown Error in Agent App generate worker")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.close()
def _run_input_guards(
self,
*,
application_generate_entity: AgentAppGenerateEntity,
app_model: App,
message: Message,
queue_manager: AppQueueManager,
) -> tuple[bool, str]:
"""Apply input moderation + annotation reply before the backend call.
Returns ``(handled, query)``: when ``handled`` is True a direct answer
has already been published (a blocked/preset moderation response or a
matched annotation) and the backend turn must be skipped. Otherwise
``query`` is the possibly moderation-overridden query to send onward.
"""
from core.app.apps.agent_app.app_runner import publish_text_answer
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
app_config = application_generate_entity.app_config
model_name = application_generate_entity.model_conf.model
query = application_generate_entity.query
# content moderation (sensitive_word_avoidance); a blocked input yields a
# preset answer, an "overridden" action returns a sanitized query.
try:
_, _, query = InputModeration().check(
app_id=app_config.app_id,
tenant_id=app_config.tenant_id,
app_config=app_config,
inputs=dict(application_generate_entity.inputs),
query=query or "",
message_id=message.id,
trace_manager=application_generate_entity.trace_manager,
)
except ModerationError as e:
publish_text_answer(queue_manager=queue_manager, model_name=model_name, answer=str(e))
return True, query
# annotation reply: a matching annotation answers the turn deterministically.
if query:
annotation_reply = AnnotationReplyFeature().query(
app_record=app_model,
message=message,
query=query,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
)
if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER,
)
publish_text_answer(queue_manager=queue_manager, model_name=model_name, answer=annotation_reply.content)
return True, query
return False, query
def _resolve_agent(self, app_model: App) -> tuple[Agent, AgentConfigSnapshot, AgentSoulConfig]:
agent = db.session.scalar(
select(Agent).where(
Agent.app_id == app_model.id,
Agent.scope == AgentScope.ROSTER,
Agent.source == AgentSource.AGENT_APP,
Agent.status == AgentStatus.ACTIVE,
)
)
if agent is None:
raise AgentAppGeneratorError("Agent App has no bound Agent")
return self._resolve_agent_by_id(
tenant_id=app_model.tenant_id, agent_id=agent.id, snapshot_id=agent.active_config_snapshot_id
)
@staticmethod
def _resolve_agent_by_id(
*, tenant_id: str, agent_id: str, snapshot_id: str | None
) -> tuple[Agent, AgentConfigSnapshot, AgentSoulConfig]:
agent = db.session.scalar(select(Agent).where(Agent.id == agent_id, Agent.tenant_id == tenant_id))
if agent is None:
raise AgentAppGeneratorError("Agent not found")
if not snapshot_id:
raise AgentAppGeneratorError("Agent has no published version")
snapshot = db.session.scalar(select(AgentConfigSnapshot).where(AgentConfigSnapshot.id == snapshot_id))
if snapshot is None:
raise AgentAppGeneratorError("Agent published version not found")
agent_soul = AgentSoulConfig.model_validate(snapshot.config_snapshot_dict)
return agent, snapshot, agent_soul
__all__ = ["AgentAppGenerator", "AgentAppGeneratorError"]

View File

@ -0,0 +1,200 @@
"""Agent App runner: drive one conversation turn through the dify-agent backend.
Unlike the legacy ``AgentChatAppRunner`` (which runs an in-process ReAct loop),
this runner delegates to the Agent backend: build the run request from the
Agent Soul + conversation, create the run, consume its event stream, and
republish the assistant answer as chat queue events so the existing
EasyUI chat task pipeline persists the message and streams SSE. The conversation
``session_snapshot`` is saved on success for multi-turn continuity (S3).
"""
from __future__ import annotations
import json
import logging
from typing import Any
from pydantic import JsonValue
from clients.agent_backend import (
AgentBackendError,
AgentBackendInternalEventType,
AgentBackendRunClient,
AgentBackendRunEventAdapter,
AgentBackendRunSucceededInternalEvent,
AgentBackendStreamInternalEvent,
)
from core.app.apps.agent_app.runtime_request_builder import (
AgentAppRuntimeBuildContext,
AgentAppRuntimeRequestBuilder,
)
from core.app.apps.agent_app.session_store import AgentAppRuntimeSessionStore, AgentAppSessionScope
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import DifyRunContext
from core.app.entities.queue_entities import QueueLLMChunkEvent, QueueMessageEndEvent
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage
from models.agent_config_entities import AgentSoulConfig
logger = logging.getLogger(__name__)
def publish_text_answer(*, queue_manager: AppQueueManager, model_name: str, answer: str) -> None:
"""Publish a complete assistant answer as one chunk + message-end.
The EasyUI chat task pipeline consumes a QueueLLMChunkEvent stream followed
by a QueueMessageEndEvent; emitting the whole answer as a single chunk lets
both the backend-produced answer and short-circuited answers (moderation /
annotation reply) share the exact same persistence + SSE path.
"""
chunk = LLMResultChunk(
model=model_name,
prompt_messages=[],
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=answer)),
)
queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER)
queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_name,
prompt_messages=[],
message=AssistantPromptMessage(content=answer),
usage=LLMUsage.empty_usage(),
),
),
PublishFrom.APPLICATION_MANAGER,
)
class AgentAppRunner:
"""Runs one Agent App conversation turn against the Agent backend."""
def __init__(
self,
*,
request_builder: AgentAppRuntimeRequestBuilder,
agent_backend_client: AgentBackendRunClient,
event_adapter: AgentBackendRunEventAdapter,
session_store: AgentAppRuntimeSessionStore,
) -> None:
self._request_builder = request_builder
self._agent_backend_client = agent_backend_client
self._event_adapter = event_adapter
self._session_store = session_store
def run(
self,
*,
dify_context: DifyRunContext,
agent_id: str,
agent_config_snapshot_id: str,
agent_soul: AgentSoulConfig,
conversation_id: str,
query: str,
message_id: str,
model_name: str,
queue_manager: AppQueueManager,
) -> None:
scope = AgentAppSessionScope(
tenant_id=dify_context.tenant_id,
app_id=dify_context.app_id,
conversation_id=conversation_id,
agent_id=agent_id,
agent_config_snapshot_id=agent_config_snapshot_id,
)
session_snapshot = self._session_store.load_active_snapshot(scope)
runtime = self._request_builder.build(
AgentAppRuntimeBuildContext(
dify_context=dify_context,
agent_id=agent_id,
agent_config_snapshot_id=agent_config_snapshot_id,
agent_soul=agent_soul,
conversation_id=conversation_id,
user_query=query,
idempotency_key=message_id,
session_snapshot=session_snapshot,
)
)
create_response = self._agent_backend_client.create_run(runtime.request)
terminal = self._consume_stream(create_response.run_id, queue_manager=queue_manager)
if not isinstance(terminal, AgentBackendRunSucceededInternalEvent):
error = getattr(terminal, "error", None) or "Agent backend run did not complete successfully."
raise AgentBackendError(str(error))
answer = self._extract_answer(terminal.output)
self._publish_answer(queue_manager=queue_manager, model_name=model_name, answer=answer)
self._save_session(scope=scope, backend_run_id=terminal.run_id, snapshot=terminal.session_snapshot)
def _consume_stream(self, run_id: str, *, queue_manager: AppQueueManager):
terminal = None
for public_event in self._agent_backend_client.stream_events(run_id):
if queue_manager.is_stopped():
self._cancel_run(run_id)
raise GenerateTaskStoppedError()
for internal_event in self._event_adapter.adapt(public_event):
if queue_manager.is_stopped():
self._cancel_run(run_id)
raise GenerateTaskStoppedError()
if internal_event.type in (
AgentBackendInternalEventType.RUN_STARTED,
AgentBackendInternalEventType.STREAM_EVENT,
):
# Stream deltas are accumulated by the backend into the
# terminal output; token-level forwarding is an S3 refinement.
if isinstance(internal_event, AgentBackendStreamInternalEvent):
continue
continue
terminal = internal_event
break
if terminal is not None:
break
return terminal
def _cancel_run(self, run_id: str) -> None:
try:
self._agent_backend_client.cancel_run(run_id)
except Exception:
logger.warning("Failed to cancel stopped Agent App backend run: run_id=%s", run_id, exc_info=True)
def _publish_answer(self, *, queue_manager: AppQueueManager, model_name: str, answer: str) -> None:
# MVP: emit the full answer as a single chunk + message-end. The chat
# task pipeline streams the chunk over SSE and persists the message.
publish_text_answer(queue_manager=queue_manager, model_name=model_name, answer=answer)
def _save_session(self, *, scope: AgentAppSessionScope, backend_run_id: str, snapshot: Any) -> None:
try:
self._session_store.save_active_snapshot(scope=scope, backend_run_id=backend_run_id, snapshot=snapshot)
except Exception:
logger.warning(
"Failed to persist Agent App conversation session snapshot: "
"tenant_id=%s app_id=%s conversation_id=%s agent_id=%s",
scope.tenant_id,
scope.app_id,
scope.conversation_id,
scope.agent_id,
exc_info=True,
)
@staticmethod
def _extract_answer(output: JsonValue) -> str:
"""Normalize the backend's terminal output to assistant text.
Free-text Agent Apps return a plain string; if a structured output is
configured the value is a JSON object, which we serialize so the chat
message always has a string body.
"""
if isinstance(output, str):
return output
if isinstance(output, dict):
text = output.get("text")
if isinstance(text, str):
return text
return json.dumps(output, ensure_ascii=False)
return json.dumps(output, ensure_ascii=False)
__all__ = ["AgentAppRunner", "publish_text_answer"]

View File

@ -0,0 +1,15 @@
"""Response converter for the Agent App type.
The Agent App streams the same chatbot response shape as the chat / agent-chat
app types, so it reuses that converter wholesale; kept as a distinct subclass so
the app type owns its converter and can diverge later.
"""
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
class AgentAppGenerateResponseConverter(AgentChatAppGenerateResponseConverter):
pass
__all__ = ["AgentAppGenerateResponseConverter"]

View File

@ -0,0 +1,177 @@
"""Build dify-agent run requests for one Agent App conversation turn.
Mirrors the workflow ``WorkflowAgentRuntimeRequestBuilder`` but for the Agent
App surface: the user prompt is the chat message (no workflow-node job / no
previous-node context), and multi-turn continuity flows through the
conversation-keyed ``session_snapshot`` plus the history layer.
"""
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Protocol, cast
from agenton.compositor import CompositorSessionSnapshot
from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig
from dify_agent.protocol import CreateRunRequest
from clients.agent_backend import (
AgentBackendAgentAppRunInput,
AgentBackendModelConfig,
AgentBackendRunRequestBuilder,
redact_for_agent_backend_log,
)
from core.app.entities.app_invoke_entities import DifyRunContext
from core.workflow.nodes.agent_v2.plugin_tools_builder import (
WorkflowAgentPluginToolsBuilder,
WorkflowAgentPluginToolsBuildError,
)
from models.agent_config_entities import AgentSoulConfig
from models.provider_ids import ModelProviderID
class AgentAppRuntimeRequestBuildError(ValueError):
"""Raised when Agent App state cannot be mapped to a valid run request."""
def __init__(self, error_code: str, message: str) -> None:
self.error_code = error_code
super().__init__(message)
class CredentialsProvider(Protocol):
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: ...
@dataclass(frozen=True, slots=True)
class AgentAppRuntimeBuildContext:
dify_context: DifyRunContext
agent_id: str
agent_config_snapshot_id: str
agent_soul: AgentSoulConfig
conversation_id: str
user_query: str
idempotency_key: str
session_snapshot: CompositorSessionSnapshot | None = None
@dataclass(frozen=True, slots=True)
class AgentAppRuntimeRequest:
request: CreateRunRequest
redacted_request: dict[str, Any]
metadata: dict[str, Any]
class AgentAppRuntimeRequestBuilder:
"""Build dify-agent run requests from Agent App conversation state."""
def __init__(
self,
*,
credentials_provider: CredentialsProvider,
request_builder: AgentBackendRunRequestBuilder | None = None,
plugin_tools_builder: WorkflowAgentPluginToolsBuilder | None = None,
) -> None:
self._credentials_provider = credentials_provider
self._request_builder = request_builder or AgentBackendRunRequestBuilder()
self._plugin_tools_builder = plugin_tools_builder or WorkflowAgentPluginToolsBuilder()
def build(self, context: AgentAppRuntimeBuildContext) -> AgentAppRuntimeRequest:
agent_soul = context.agent_soul
if agent_soul.model is None:
raise AgentAppRuntimeRequestBuildError(
"agent_model_not_configured",
"Agent App requires the Agent Soul model to be configured.",
)
metadata = self._build_metadata(context)
credentials = self._credentials_provider.fetch(agent_soul.model.model_provider, agent_soul.model.model)
try:
tools_layer = self._plugin_tools_builder.build(
tenant_id=context.dify_context.tenant_id,
app_id=context.dify_context.app_id,
user_id=context.dify_context.user_id,
tools=agent_soul.tools,
invoke_from=context.dify_context.invoke_from,
)
except WorkflowAgentPluginToolsBuildError as error:
raise AgentAppRuntimeRequestBuildError(error.error_code, str(error)) from error
if tools_layer is not None:
metadata["agent_tools"] = {
"dify_tool_count": len(tools_layer.tools),
"dify_tool_names": [tool.name or tool.tool_name for tool in tools_layer.tools],
}
request = self._request_builder.build_for_agent_app(
AgentBackendAgentAppRunInput(
model=AgentBackendModelConfig(
plugin_id=self._plugin_daemon_plugin_id(
plugin_id=agent_soul.model.plugin_id,
model_provider=agent_soul.model.model_provider,
),
model_provider=self._plugin_daemon_provider_name(agent_soul.model.model_provider),
model=agent_soul.model.model,
credentials=self._normalize_credentials(credentials),
model_settings=agent_soul.model.model_settings.model_dump(mode="json", exclude_none=True),
),
execution_context=DifyExecutionContextLayerConfig(
tenant_id=context.dify_context.tenant_id,
user_id=context.dify_context.user_id,
app_id=context.dify_context.app_id,
conversation_id=context.conversation_id,
agent_id=context.agent_id,
agent_config_version_id=context.agent_config_snapshot_id,
invoke_from="agent_app",
),
agent_soul_prompt=agent_soul.prompt.system_prompt or None,
user_prompt=context.user_query,
tools=tools_layer,
session_snapshot=context.session_snapshot,
idempotency_key=context.idempotency_key,
metadata=metadata,
)
)
redacted = cast(dict[str, Any], redact_for_agent_backend_log(request))
return AgentAppRuntimeRequest(request=request, redacted_request=redacted, metadata=metadata)
@staticmethod
def _build_metadata(context: AgentAppRuntimeBuildContext) -> dict[str, Any]:
return {
"tenant_id": context.dify_context.tenant_id,
"app_id": context.dify_context.app_id,
"conversation_id": context.conversation_id,
"agent_id": context.agent_id,
"agent_config_snapshot_id": context.agent_config_snapshot_id,
}
@staticmethod
def _plugin_daemon_plugin_id(*, plugin_id: str, model_provider: str) -> str:
"""Return the transport plugin id expected by plugin-daemon headers."""
if plugin_id.count("/") == 1:
return plugin_id
if plugin_id:
return ModelProviderID(plugin_id).plugin_id
return ModelProviderID(model_provider).plugin_id
@staticmethod
def _plugin_daemon_provider_name(model_provider: str) -> str:
"""Return the provider name expected by plugin-daemon dispatch payloads."""
return ModelProviderID(model_provider).provider_name
@staticmethod
def _normalize_credentials(credentials: Mapping[str, Any]) -> dict[str, str | int | float | bool | None]:
normalized: dict[str, str | int | float | bool | None] = {}
for key, value in credentials.items():
if isinstance(value, str | int | float | bool) or value is None:
normalized[key] = value
else:
normalized[key] = str(value)
return normalized
__all__ = [
"AgentAppRuntimeBuildContext",
"AgentAppRuntimeRequest",
"AgentAppRuntimeRequestBuildError",
"AgentAppRuntimeRequestBuilder",
]

View File

@ -0,0 +1,106 @@
"""Conversation-keyed Agent backend session store for the Agent App type.
Shares the unified ``agent_runtime_sessions`` table with the workflow Agent
Node store, but owns rows with ``owner_type = conversation``: one Agent App
conversation maps to one Agent session, so multi-turn chat re-enters the same
``session_snapshot``. Cross-conversation memory (PRD Global / Per app) is a
phase-2 concern and not modeled here.
"""
from __future__ import annotations
from dataclasses import dataclass
from agenton.compositor import CompositorSessionSnapshot
from sqlalchemy import select
from core.db.session_factory import session_factory
from libs.datetime_utils import naive_utc_now
from models.agent import (
AgentRuntimeSession,
AgentRuntimeSessionOwnerType,
AgentRuntimeSessionStatus,
)
@dataclass(frozen=True, slots=True)
class AgentAppSessionScope:
"""Identity of one Agent App conversation session."""
tenant_id: str
app_id: str
conversation_id: str
agent_id: str
agent_config_snapshot_id: str
class AgentAppRuntimeSessionStore:
"""Persists Agent backend session snapshots for Agent App conversations."""
def load_active_snapshot(self, scope: AgentAppSessionScope) -> CompositorSessionSnapshot | None:
with session_factory.create_session() as session:
row = session.scalar(self._active_stmt(scope))
if row is None:
return None
return CompositorSessionSnapshot.model_validate_json(row.session_snapshot)
def save_active_snapshot(
self,
*,
scope: AgentAppSessionScope,
backend_run_id: str,
snapshot: CompositorSessionSnapshot | None,
) -> None:
if snapshot is None:
return
snapshot_json = snapshot.model_dump_json()
with session_factory.create_session() as session:
row = session.scalar(self._scope_stmt(scope))
if row is None:
row = AgentRuntimeSession(
tenant_id=scope.tenant_id,
app_id=scope.app_id,
owner_type=AgentRuntimeSessionOwnerType.CONVERSATION,
agent_id=scope.agent_id,
agent_config_snapshot_id=scope.agent_config_snapshot_id,
conversation_id=scope.conversation_id,
backend_run_id=backend_run_id,
session_snapshot=snapshot_json,
composition_layer_specs="[]",
status=AgentRuntimeSessionStatus.ACTIVE,
)
session.add(row)
else:
row.backend_run_id = backend_run_id
row.session_snapshot = snapshot_json
row.status = AgentRuntimeSessionStatus.ACTIVE
row.cleaned_at = None
session.commit()
def mark_cleaned(self, *, scope: AgentAppSessionScope, backend_run_id: str | None = None) -> None:
with session_factory.create_session() as session:
row = session.scalar(self._active_stmt(scope))
if row is None:
return
if backend_run_id is not None:
row.backend_run_id = backend_run_id
row.status = AgentRuntimeSessionStatus.CLEANED
row.cleaned_at = naive_utc_now()
session.commit()
@staticmethod
def _scope_stmt(scope: AgentAppSessionScope):
return select(AgentRuntimeSession).where(
AgentRuntimeSession.owner_type == AgentRuntimeSessionOwnerType.CONVERSATION,
AgentRuntimeSession.tenant_id == scope.tenant_id,
AgentRuntimeSession.conversation_id == scope.conversation_id,
AgentRuntimeSession.agent_id == scope.agent_id,
AgentRuntimeSession.agent_config_snapshot_id == scope.agent_config_snapshot_id,
)
@classmethod
def _active_stmt(cls, scope: AgentAppSessionScope):
return cls._scope_stmt(scope).where(AgentRuntimeSession.status == AgentRuntimeSessionStatus.ACTIVE)
__all__ = ["AgentAppRuntimeSessionStore", "AgentAppSessionScope"]

View File

@ -112,15 +112,16 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
elif isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
match sub_stream_response:
case MessageEndStreamResponse():
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
case ErrorStreamResponse():
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
case _:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -134,6 +134,10 @@ class AppQueueManager(ABC):
self._check_for_sqlalchemy_models(event.model_dump())
self._publish(event, pub_from)
def is_stopped(self) -> bool:
"""Return whether the current task has been manually stopped."""
return self._is_stopped()
@abstractmethod
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
@ -209,14 +213,16 @@ class AppQueueManager(ABC):
def _check_for_sqlalchemy_models(self, data: Any):
# from entity to dict or list
if isinstance(data, dict):
for value in data.values():
self._check_for_sqlalchemy_models(value)
elif isinstance(data, list):
for item in data:
self._check_for_sqlalchemy_models(item)
else:
if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
raise TypeError(
"Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed."
)
match data:
case dict():
for value in data.values():
self._check_for_sqlalchemy_models(value)
case list():
for item in data:
self._check_for_sqlalchemy_models(item)
case _:
if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
raise TypeError(
"Critical Error: Passing SQLAlchemy Model instances that"
" cause thread safety issues is not allowed."
)

View File

@ -305,26 +305,27 @@ class AppRunner:
text += message.content
elif isinstance(message.content, list):
for content in message.content:
if isinstance(content, str):
text += content
elif isinstance(content, TextPromptMessageContent):
text += content.data
elif isinstance(content, ImagePromptMessageContent):
if message_id and user_id and tenant_id:
try:
self._handle_multimodal_image_content(
content=content,
message_id=message_id,
user_id=user_id,
tenant_id=tenant_id,
queue_manager=queue_manager,
)
except Exception:
_logger.exception("Failed to handle multimodal image output")
else:
_logger.warning("Received multimodal output but missing required parameters")
else:
text += content.data if hasattr(content, "data") else str(content)
match content:
case str():
text += content
case TextPromptMessageContent():
text += content.data
case ImagePromptMessageContent():
if message_id and user_id and tenant_id:
try:
self._handle_multimodal_image_content(
content=content,
message_id=message_id,
user_id=user_id,
tenant_id=tenant_id,
queue_manager=queue_manager,
)
except Exception:
_logger.exception("Failed to handle multimodal image output")
else:
_logger.warning("Received multimodal output but missing required parameters")
case _:
text += content.data if hasattr(content, "data") else str(content)
if not model:
model = result.model

View File

@ -112,15 +112,16 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
elif isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
match sub_stream_response:
case MessageEndStreamResponse():
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
case ErrorStreamResponse():
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
case _:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -276,17 +276,18 @@ class WorkflowResponseConverter:
created_by: CreatedByDict | dict[str, object] = {}
user = self._user
if isinstance(user, Account):
created_by = AccountCreatedByDict(
id=user.id,
name=user.name,
email=user.email,
)
elif isinstance(user, EndUser):
created_by = EndUserCreatedByDict(
id=user.id,
user=user.session_id,
)
match user:
case Account():
created_by = AccountCreatedByDict(
id=user.id,
name=user.name,
email=user.email,
)
case EndUser():
created_by = EndUserCreatedByDict(
id=user.id,
user=user.session_id,
)
return WorkflowFinishStreamResponse(
task_id=task_id,
@ -455,17 +456,18 @@ class WorkflowResponseConverter:
created_by: Mapping[str, object]
user = creator_user
if isinstance(user, Account):
created_by = {
"id": user.id,
"name": user.name,
"email": user.email,
}
else:
created_by = {
"id": user.id,
"user": user.session_id,
}
match user:
case Account():
created_by = {
"id": user.id,
"name": user.name,
"email": user.email,
}
case _:
created_by = {
"id": user.id,
"user": user.session_id,
}
return WorkflowFinishStreamResponse(
task_id=task_id,
@ -562,15 +564,16 @@ class WorkflowResponseConverter:
outputs, outputs_truncated = self._truncate_mapping(encoded_outputs)
metadata = self._merge_metadata(event.execution_metadata, snapshot)
if isinstance(event, QueueNodeSucceededEvent):
status = WorkflowNodeExecutionStatus.SUCCEEDED
error_message = event.error
elif isinstance(event, QueueNodeFailedEvent):
status = WorkflowNodeExecutionStatus.FAILED
error_message = event.error
else:
status = WorkflowNodeExecutionStatus.EXCEPTION
error_message = event.error
match event:
case QueueNodeSucceededEvent():
status = WorkflowNodeExecutionStatus.SUCCEEDED
error_message = event.error
case QueueNodeFailedEvent():
status = WorkflowNodeExecutionStatus.FAILED
error_message = event.error
case _:
status = WorkflowNodeExecutionStatus.EXCEPTION
error_message = event.error
return NodeFinishStreamResponse(
task_id=task_id,

View File

@ -109,17 +109,18 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
if not isinstance(metadata, dict):
metadata = {}
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
elif isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
match sub_stream_response:
case MessageEndStreamResponse():
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
if not isinstance(metadata, dict):
metadata = {}
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
case ErrorStreamResponse():
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
case _:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any, cast
from typing import Any, cast, override
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@ -15,6 +15,7 @@ from core.app.entities.task_entities import (
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]):
@classmethod
@override
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
"""
Convert blocking full response.
@ -24,6 +25,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[Workflow
return dict(blocking_response.model_dump())
@classmethod
@override
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
"""
Convert blocking simple response.
@ -33,6 +35,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[Workflow
return cls.convert_blocking_full_response(blocking_response)
@classmethod
@override
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:
@ -66,6 +69,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[Workflow
yield response_chunk
@classmethod
@override
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:

View File

@ -200,6 +200,21 @@ class AgentChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGe
pass
class AgentAppGenerateEntity(ChatAppGenerateEntity):
"""
Agent App (new Agent app type) Generate Entity.
Subclasses ``ChatAppGenerateEntity`` so it rides the exact same EasyUI chat
pipeline (generator, task pipeline, message cycle) without widening every
accepted-entity union. The answer is produced by the dify-agent backend
rather than an in-process LLM call; ``model_conf`` is synthesized from the
bound Agent Soul model so the chat task pipeline can persist usage.
"""
agent_id: str
agent_config_snapshot_id: str
class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
"""
Advanced Chat Application Generate Entity.

View File

@ -46,13 +46,14 @@ class BasedGenerateTaskPipeline:
e = event.error
err: Exception
if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided")
elif isinstance(e, InvokeError | ValueError):
err = e
else:
description = getattr(e, "description", None)
err = Exception(description if description is not None else str(e))
match e:
case InvokeAuthorizationError():
err = InvokeAuthorizationError("Incorrect API key provided")
case InvokeError() | ValueError():
err = e
case _:
description = getattr(e, "description", None)
err = Exception(description if description is not None else str(e))
if not message_id or not session:
return err

Some files were not shown because too many files have changed in this diff Show More