mirror of
https://github.com/langgenius/dify.git
synced 2026-04-16 18:46:14 +08:00
Compare commits
176 Commits
dependabot
...
feat/dropd
| Author | SHA1 | Date | |
|---|---|---|---|
| 389901ded8 | |||
| 9356b30bf2 | |||
| 79d87e6000 | |||
| a13996dba1 | |||
| c3eff6abdc | |||
| c661d5c43a | |||
| b665eaa015 | |||
| b08665e598 | |||
| 883d757392 | |||
| 25df902ec4 | |||
| 5956dd79df | |||
| 70556d9386 | |||
| 9fa50774b4 | |||
| 731414a44f | |||
| d42d08aa57 | |||
| 987b5f4bf4 | |||
| 665978a602 | |||
| 8baa864c35 | |||
| 53a22aa41b | |||
| cf4d7afb9c | |||
| e6b5923ff1 | |||
| 538093855b | |||
| e6b8cbe657 | |||
| af7d5e60b4 | |||
| dbceb3067e | |||
| 425457cb16 | |||
| e5bd18132c | |||
| 2f33867d07 | |||
| fd71c56f16 | |||
| e3c2116501 | |||
| fb17339d89 | |||
| 9fd196642d | |||
| 98897a5379 | |||
| 5542329554 | |||
| 79332c0e5e | |||
| 50a55513d4 | |||
| 3bccdd6c9a | |||
| 76af80e332 | |||
| 7a880ae60c | |||
| 5bc0f9513b | |||
| b77801ece9 | |||
| 7de92c598f | |||
| 693080aa12 | |||
| 25c388d0db | |||
| b1722c8af9 | |||
| b65a5fcd97 | |||
| 1c3cba281a | |||
| 800954f8ce | |||
| f66a3c49c4 | |||
| ef396ac84e | |||
| 7e7b27fdec | |||
| 9c90c1c455 | |||
| b1df52b8ff | |||
| e527b7c5f1 | |||
| 149b9d4c0f | |||
| ef28a63ad3 | |||
| e78558bc06 | |||
| f63d7c4121 | |||
| ef062fb397 | |||
| a2ea7ca039 | |||
| 6876cd787b | |||
| 50a6892c3a | |||
| 1bcc7f78c7 | |||
| 2fd5b76ac1 | |||
| 62f42b3f24 | |||
| 2c58b424a1 | |||
| 381c518b23 | |||
| ebf741114d | |||
| 648dde5e96 | |||
| a3042e6332 | |||
| e5fd3133f4 | |||
| e1bbe57f9c | |||
| d4783e8c14 | |||
| 736880e046 | |||
| bd7a9b5fcf | |||
| 9a47bb2f80 | |||
| d7ad2baf79 | |||
| a951cc996b | |||
| 173e0d6f35 | |||
| 62bb830338 | |||
| f7c6270f74 | |||
| 711fe6ba2c | |||
| fbedb60371 | |||
| 974d2f1627 | |||
| ed401728eb | |||
| fc389a54c5 | |||
| c8b372dba0 | |||
| 2333d75c56 | |||
| 2ef9a8a769 | |||
| 21ab9b9d8c | |||
| 79c1473378 | |||
| 93b8a74351 | |||
| 28185170b0 | |||
| 178883b4cc | |||
| e9f9041b25 | |||
| 175290fa04 | |||
| b0c4d8c541 | |||
| 0f643bca76 | |||
| eeebedcfe8 | |||
| 2f682780fa | |||
| ed83f5369e | |||
| 4ee1bd5f32 | |||
| 1c2bbed405 | |||
| d573fc0e65 | |||
| f8b249e649 | |||
| fbcab757d5 | |||
| c0e998ef6e | |||
| 84f25807db | |||
| 83b242be7b | |||
| a12d740a5d | |||
| 3bbb014dc7 | |||
| f040733e28 | |||
| b0bf7ca486 | |||
| 14d83c8bac | |||
| 8b506dfa42 | |||
| ac2258c2dc | |||
| 3c279edcf2 | |||
| 9ed8a5ed73 | |||
| 3d4ddf4a6f | |||
| 4e0273bb28 | |||
| 7056d2ae99 | |||
| d8fbc00cb9 | |||
| 57c5f0ec87 | |||
| e5bd80c719 | |||
| 25a33a454c | |||
| bd30784b1d | |||
| 28fce0a890 | |||
| e1eb582bea | |||
| 2042ee453b | |||
| 33c4e512f1 | |||
| 253e8a3f98 | |||
| 06b63d65d1 | |||
| 08f3133414 | |||
| d412cddf39 | |||
| 671c5cdd84 | |||
| 554f060092 | |||
| e243e8d8a3 | |||
| 1b935a367f | |||
| 2edd083a71 | |||
| dd50a68bf2 | |||
| e8dd3461e8 | |||
| 8dd4473432 | |||
| b5bbbdd840 | |||
| f0266e13c5 | |||
| ae898652b2 | |||
| c34f67495c | |||
| 815c536e05 | |||
| fc64427ae1 | |||
| 11c518478e | |||
| e823635ce1 | |||
| 98e74c8fde | |||
| 29bfa33d59 | |||
| 3ead0beeb1 | |||
| 2108c44c8b | |||
| b0079e55b4 | |||
| d9f54f8bd7 | |||
| 5a446f8200 | |||
| f4d5e2f43d | |||
| 9121f24181 | |||
| 7dd507af04 | |||
| 3b9aad2ba7 | |||
| ea9f74b581 | |||
| e37aaa482d | |||
| a3170f744c | |||
| ced3780787 | |||
| 6faf26683c | |||
| 8ac9cbf733 | |||
| 098ed34469 | |||
| 6cf4d1002f | |||
| a111d56ea3 | |||
| 8436470fcb | |||
| 17da0e4146 | |||
| ea41e9ab4e | |||
| 5770b5feef | |||
| b5259a3a85 | |||
| 596559efc9 |
79
.agents/skills/e2e-cucumber-playwright/SKILL.md
Normal file
79
.agents/skills/e2e-cucumber-playwright/SKILL.md
Normal file
@ -0,0 +1,79 @@
|
||||
---
|
||||
name: e2e-cucumber-playwright
|
||||
description: Write, update, or review Dify end-to-end tests under `e2e/` that use Cucumber, Gherkin, and Playwright. Use when the task involves `.feature` files, `features/step-definitions/`, `features/support/`, `DifyWorld`, scenario tags, locator/assertion choices, or E2E testing best practices for this repository.
|
||||
---
|
||||
|
||||
# Dify E2E Cucumber + Playwright
|
||||
|
||||
Use this skill for Dify's repository-level E2E suite in `e2e/`. Use [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) as the canonical guide for local architecture and conventions, then apply Playwright/Cucumber best practices only where they fit the current suite.
|
||||
|
||||
## Scope
|
||||
|
||||
- Use this skill for `.feature` files, Cucumber step definitions, `DifyWorld`, hooks, tags, and E2E review work under `e2e/`.
|
||||
- Do not use this skill for Vitest or React Testing Library work under `web/`; use `frontend-testing` instead.
|
||||
- Do not use this skill for backend test or API review tasks under `api/`.
|
||||
|
||||
## Read Order
|
||||
|
||||
1. Read [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) first.
|
||||
2. Read only the files directly involved in the task:
|
||||
- target `.feature` files under `e2e/features/`
|
||||
- related step files under `e2e/features/step-definitions/`
|
||||
- `e2e/features/support/hooks.ts` and `e2e/features/support/world.ts` when session lifecycle or shared state matters
|
||||
- `e2e/scripts/run-cucumber.ts` and `e2e/cucumber.config.ts` when tags or execution flow matter
|
||||
3. Read [`references/playwright-best-practices.md`](references/playwright-best-practices.md) only when locator, assertion, isolation, or waiting choices are involved.
|
||||
4. Read [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md) only when scenario wording, step granularity, tags, or expression design are involved.
|
||||
5. Re-check official docs with Context7 before introducing a new Playwright or Cucumber pattern.
|
||||
|
||||
## Local Rules
|
||||
|
||||
- `e2e/` uses Cucumber for scenarios and Playwright as the browser layer.
|
||||
- `DifyWorld` is the per-scenario context object. Type `this` as `DifyWorld` and use `async function`, not arrow functions.
|
||||
- Keep glue organized by capability under `e2e/features/step-definitions/`; use `common/` only for broadly reusable steps.
|
||||
- Browser session behavior comes from `features/support/hooks.ts`:
|
||||
- default: authenticated session with shared storage state
|
||||
- `@unauthenticated`: clean browser context
|
||||
- `@authenticated`: readability/selective-run tag only unless implementation changes
|
||||
- `@fresh`: only for `e2e:full*` flows
|
||||
- Do not import Playwright Test runner patterns that bypass the current Cucumber + `DifyWorld` architecture unless the task is explicitly about changing that architecture.
|
||||
|
||||
## Workflow
|
||||
|
||||
1. Rebuild local context.
|
||||
- Inspect the target feature area.
|
||||
- Reuse an existing step when wording and behavior already match.
|
||||
- Add a new step only for a genuinely new user action or assertion.
|
||||
- Keep edits close to the current capability folder unless the step is broadly reusable.
|
||||
2. Write behavior-first scenarios.
|
||||
- Describe user-observable behavior, not DOM mechanics.
|
||||
- Keep each scenario focused on one workflow or outcome.
|
||||
- Keep scenarios independent and re-runnable.
|
||||
3. Write step definitions in the local style.
|
||||
- Keep one step to one user-visible action or one assertion.
|
||||
- Prefer Cucumber Expressions such as `{string}` and `{int}`.
|
||||
- Scope locators to stable containers when the page has repeated elements.
|
||||
- Avoid page-object layers or extra helper abstractions unless repeated complexity clearly justifies them.
|
||||
4. Use Playwright in the local style.
|
||||
- Prefer user-facing locators: `getByRole`, `getByLabel`, `getByPlaceholder`, `getByText`, then `getByTestId` for explicit contracts.
|
||||
- Use web-first `expect(...)` assertions.
|
||||
- Do not use `waitForTimeout`, manual polling, or raw visibility checks when a locator action or retrying assertion already expresses the behavior.
|
||||
5. Validate narrowly.
|
||||
- Run the narrowest tagged scenario or flow that exercises the change.
|
||||
- Run `pnpm -C e2e check`.
|
||||
- Broaden verification only when the change affects hooks, tags, setup, or shared step semantics.
|
||||
|
||||
## Review Checklist
|
||||
|
||||
- Does the scenario describe behavior rather than implementation?
|
||||
- Does it fit the current session model, tags, and `DifyWorld` usage?
|
||||
- Should an existing step be reused instead of adding a new one?
|
||||
- Are locators user-facing and assertions web-first?
|
||||
- Does the change introduce hidden coupling across scenarios, tags, or instance state?
|
||||
- Does it document or implement behavior that differs from the real hooks or configuration?
|
||||
|
||||
Lead findings with correctness, flake risk, and architecture drift.
|
||||
|
||||
## References
|
||||
|
||||
- [`references/playwright-best-practices.md`](references/playwright-best-practices.md)
|
||||
- [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md)
|
||||
@ -0,0 +1,4 @@
|
||||
interface:
|
||||
display_name: "E2E Cucumber + Playwright"
|
||||
short_description: "Write and review Dify E2E scenarios."
|
||||
default_prompt: "Use $e2e-cucumber-playwright to write or review a Dify E2E scenario under e2e/."
|
||||
@ -0,0 +1,93 @@
|
||||
# Cucumber Best Practices For Dify E2E
|
||||
|
||||
Use this reference when writing or reviewing Gherkin scenarios, step definitions, parameter expressions, and step reuse in Dify's `e2e/` suite.
|
||||
|
||||
Official sources:
|
||||
|
||||
- https://cucumber.io/docs/guides/10-minute-tutorial/
|
||||
- https://cucumber.io/docs/cucumber/step-definitions/
|
||||
- https://cucumber.io/docs/cucumber/cucumber-expressions/
|
||||
|
||||
## What Matters Most
|
||||
|
||||
### 1. Treat scenarios as executable specifications
|
||||
|
||||
Cucumber scenarios should describe examples of behavior, not test implementation recipes.
|
||||
|
||||
Apply it like this:
|
||||
|
||||
- write what the user does and what should happen
|
||||
- avoid UI-internal wording such as selector details, DOM structure, or component names
|
||||
- keep language concrete enough that the scenario reads like living documentation
|
||||
|
||||
### 2. Keep scenarios focused
|
||||
|
||||
A scenario should usually prove one workflow or business outcome. If a scenario wanders across several unrelated behaviors, split it.
|
||||
|
||||
In Dify's suite, this means:
|
||||
|
||||
- one capability-focused scenario per feature path
|
||||
- no long setup chains when existing bootstrap or reusable steps already cover them
|
||||
- no hidden dependency on another scenario's side effects
|
||||
|
||||
### 3. Reuse steps, but only when behavior really matches
|
||||
|
||||
Good reuse reduces duplication. Bad reuse hides meaning.
|
||||
|
||||
Prefer reuse when:
|
||||
|
||||
- the user action is genuinely the same
|
||||
- the expected outcome is genuinely the same
|
||||
- the wording stays natural across features
|
||||
|
||||
Write a new step when:
|
||||
|
||||
- the behavior is materially different
|
||||
- reusing the old wording would make the scenario misleading
|
||||
- a supposedly generic step would become an implementation-detail wrapper
|
||||
|
||||
### 4. Prefer Cucumber Expressions
|
||||
|
||||
Use Cucumber Expressions for parameters unless regex is clearly necessary.
|
||||
|
||||
Common examples:
|
||||
|
||||
- `{string}` for labels, names, and visible text
|
||||
- `{int}` for counts
|
||||
- `{float}` for decimal values
|
||||
- `{word}` only when the value is truly a single token
|
||||
|
||||
Keep expressions readable. If a step needs complicated parsing logic, first ask whether the scenario wording should be simpler.
|
||||
|
||||
### 5. Keep step definitions thin and meaningful
|
||||
|
||||
Step definitions are glue between Gherkin and automation, not a second abstraction language.
|
||||
|
||||
For Dify:
|
||||
|
||||
- type `this` as `DifyWorld`
|
||||
- use `async function`
|
||||
- keep each step to one user-visible action or assertion
|
||||
- rely on `DifyWorld` and existing support code for shared context
|
||||
- avoid leaking cross-scenario state
|
||||
|
||||
### 6. Use tags intentionally
|
||||
|
||||
Tags should communicate run scope or session semantics, not become ad hoc metadata.
|
||||
|
||||
In Dify's current suite:
|
||||
|
||||
- capability tags group related scenarios
|
||||
- `@unauthenticated` changes session behavior
|
||||
- `@authenticated` is descriptive/selective, not a behavior switch by itself
|
||||
- `@fresh` belongs to reset/full-install flows only
|
||||
|
||||
If a proposed tag implies behavior, verify that hooks or runner configuration actually implement it.
|
||||
|
||||
## Review Questions
|
||||
|
||||
- Does the scenario read like a real example of product behavior?
|
||||
- Are the steps behavior-oriented instead of implementation-oriented?
|
||||
- Is a reused step still truthful in this feature?
|
||||
- Is a new tag documenting real behavior, or inventing semantics that the suite does not implement?
|
||||
- Would a new reader understand the outcome without opening the step-definition file?
|
||||
@ -0,0 +1,96 @@
|
||||
# Playwright Best Practices For Dify E2E
|
||||
|
||||
Use this reference when writing or reviewing locator, assertion, isolation, or synchronization logic for Dify's Cucumber-based E2E suite.
|
||||
|
||||
Official sources:
|
||||
|
||||
- https://playwright.dev/docs/best-practices
|
||||
- https://playwright.dev/docs/locators
|
||||
- https://playwright.dev/docs/test-assertions
|
||||
- https://playwright.dev/docs/browser-contexts
|
||||
|
||||
## What Matters Most
|
||||
|
||||
### 1. Keep scenarios isolated
|
||||
|
||||
Playwright's model is built around clean browser contexts so one test does not leak into another. In Dify's suite, that principle maps to per-scenario session setup in `features/support/hooks.ts` and `DifyWorld`.
|
||||
|
||||
Apply it like this:
|
||||
|
||||
- do not depend on another scenario having run first
|
||||
- do not persist ad hoc scenario state outside `DifyWorld`
|
||||
- do not couple ordinary scenarios to `@fresh` behavior
|
||||
- when a flow needs special auth/session semantics, express that through the existing tag model or explicit hook changes
|
||||
|
||||
### 2. Prefer user-facing locators
|
||||
|
||||
Playwright recommends built-in locators that reflect what users perceive on the page.
|
||||
|
||||
Preferred order in this repository:
|
||||
|
||||
1. `getByRole`
|
||||
2. `getByLabel`
|
||||
3. `getByPlaceholder`
|
||||
4. `getByText`
|
||||
5. `getByTestId` when an explicit test contract is the most stable option
|
||||
|
||||
Avoid raw CSS/XPath selectors unless no stable user-facing contract exists and adding one is not practical.
|
||||
|
||||
Also remember:
|
||||
|
||||
- repeated content usually needs scoping to a stable container
|
||||
- exact text matching is often too brittle when role/name or label already exists
|
||||
- `getByTestId` is acceptable when semantics are weak but the contract is intentional
|
||||
|
||||
### 3. Use web-first assertions
|
||||
|
||||
Playwright assertions auto-wait and retry. Prefer them over manual state inspection.
|
||||
|
||||
Prefer:
|
||||
|
||||
- `await expect(page).toHaveURL(...)`
|
||||
- `await expect(locator).toBeVisible()`
|
||||
- `await expect(locator).toBeHidden()`
|
||||
- `await expect(locator).toBeEnabled()`
|
||||
- `await expect(locator).toHaveText(...)`
|
||||
|
||||
Avoid:
|
||||
|
||||
- `expect(await locator.isVisible()).toBe(true)`
|
||||
- custom polling loops for DOM state
|
||||
- `waitForTimeout` as synchronization
|
||||
|
||||
If a condition genuinely needs custom retry logic, use Playwright's polling/assertion tools deliberately and keep that choice local and explicit.
|
||||
|
||||
### 4. Let actions wait for actionability
|
||||
|
||||
Locator actions already wait for the element to be actionable. Do not preface every click/fill with extra timing logic unless the action needs a specific visible/ready assertion for clarity.
|
||||
|
||||
Good pattern:
|
||||
|
||||
- assert a meaningful visible state when that is part of the behavior
|
||||
- then click/fill/select via locator APIs
|
||||
|
||||
Bad pattern:
|
||||
|
||||
- stack arbitrary waits before every action
|
||||
- wait on unstable implementation details instead of the visible state the user cares about
|
||||
|
||||
### 5. Match debugging to the current suite
|
||||
|
||||
Playwright's wider ecosystem supports traces and rich debugging tools. Dify's current suite already captures:
|
||||
|
||||
- full-page screenshots
|
||||
- page HTML
|
||||
- console errors
|
||||
- page errors
|
||||
|
||||
Use the existing artifact flow by default. If a task is specifically about improving diagnostics, confirm the change fits the current Cucumber architecture before importing broader Playwright tooling.
|
||||
|
||||
## Review Questions
|
||||
|
||||
- Would this locator survive DOM refactors that do not change user-visible behavior?
|
||||
- Is this assertion using Playwright's retrying semantics?
|
||||
- Is any explicit wait masking a real readiness problem?
|
||||
- Does this code preserve per-scenario isolation?
|
||||
- Is a new abstraction really needed, or does it bypass the existing `DifyWorld` + step-definition model?
|
||||
1
.claude/skills/e2e-cucumber-playwright
Symbolic link
1
.claude/skills/e2e-cucumber-playwright
Symbolic link
@ -0,0 +1 @@
|
||||
../../.agents/skills/e2e-cucumber-playwright
|
||||
@ -7,7 +7,7 @@ cd web && pnpm install
|
||||
pipx install uv
|
||||
|
||||
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_publisher,trigger_refresh_executor,retention\"" >> ~/.bashrc
|
||||
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev:inspect\"" >> ~/.bashrc
|
||||
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
|
||||
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
|
||||
|
||||
100
.github/dependabot.yml
vendored
100
.github/dependabot.yml
vendored
@ -1,106 +1,6 @@
|
||||
version: 2
|
||||
|
||||
updates:
|
||||
- package-ecosystem: "pip"
|
||||
directory: "/api"
|
||||
open-pull-requests-limit: 10
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
groups:
|
||||
flask:
|
||||
patterns:
|
||||
- "flask"
|
||||
- "flask-*"
|
||||
- "werkzeug"
|
||||
- "gunicorn"
|
||||
google:
|
||||
patterns:
|
||||
- "google-*"
|
||||
- "googleapis-*"
|
||||
opentelemetry:
|
||||
patterns:
|
||||
- "opentelemetry-*"
|
||||
pydantic:
|
||||
patterns:
|
||||
- "pydantic"
|
||||
- "pydantic-*"
|
||||
llm:
|
||||
patterns:
|
||||
- "langfuse"
|
||||
- "langsmith"
|
||||
- "litellm"
|
||||
- "mlflow*"
|
||||
- "opik"
|
||||
- "weave*"
|
||||
- "arize*"
|
||||
- "tiktoken"
|
||||
- "transformers"
|
||||
database:
|
||||
patterns:
|
||||
- "sqlalchemy"
|
||||
- "psycopg2*"
|
||||
- "psycogreen"
|
||||
- "redis*"
|
||||
- "alembic*"
|
||||
storage:
|
||||
patterns:
|
||||
- "boto3*"
|
||||
- "botocore*"
|
||||
- "azure-*"
|
||||
- "bce-*"
|
||||
- "cos-python-*"
|
||||
- "esdk-obs-*"
|
||||
- "google-cloud-storage"
|
||||
- "opendal"
|
||||
- "oss2"
|
||||
- "supabase*"
|
||||
- "tos*"
|
||||
vdb:
|
||||
patterns:
|
||||
- "alibabacloud*"
|
||||
- "chromadb"
|
||||
- "clickhouse-*"
|
||||
- "clickzetta-*"
|
||||
- "couchbase"
|
||||
- "elasticsearch"
|
||||
- "opensearch-py"
|
||||
- "oracledb"
|
||||
- "pgvect*"
|
||||
- "pymilvus"
|
||||
- "pymochow"
|
||||
- "pyobvector"
|
||||
- "qdrant-client"
|
||||
- "intersystems-*"
|
||||
- "tablestore"
|
||||
- "tcvectordb"
|
||||
- "tidb-vector"
|
||||
- "upstash-*"
|
||||
- "volcengine-*"
|
||||
- "weaviate-*"
|
||||
- "xinference-*"
|
||||
- "mo-vector"
|
||||
- "mysql-connector-*"
|
||||
dev:
|
||||
patterns:
|
||||
- "coverage"
|
||||
- "dotenv-linter"
|
||||
- "faker"
|
||||
- "lxml-stubs"
|
||||
- "basedpyright"
|
||||
- "ruff"
|
||||
- "pytest*"
|
||||
- "types-*"
|
||||
- "boto3-stubs"
|
||||
- "hypothesis"
|
||||
- "pandas-stubs"
|
||||
- "scipy-stubs"
|
||||
- "import-linter"
|
||||
- "celery-types"
|
||||
- "mypy*"
|
||||
- "pyrefly"
|
||||
python-packages:
|
||||
patterns:
|
||||
- "*"
|
||||
- package-ecosystem: "uv"
|
||||
directory: "/api"
|
||||
open-pull-requests-limit: 10
|
||||
|
||||
8
.github/pull_request_template.md
vendored
8
.github/pull_request_template.md
vendored
@ -18,7 +18,7 @@
|
||||
## Checklist
|
||||
|
||||
- [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
|
||||
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
|
||||
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
|
||||
- [x] I've updated the documentation accordingly.
|
||||
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods
|
||||
- [ ] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
|
||||
- [ ] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
|
||||
- [ ] I've updated the documentation accordingly.
|
||||
- [ ] I ran `make lint && make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods
|
||||
|
||||
7
.github/workflows/docker-build.yml
vendored
7
.github/workflows/docker-build.yml
vendored
@ -6,14 +6,7 @@ on:
|
||||
- "main"
|
||||
paths:
|
||||
- api/Dockerfile
|
||||
- web/docker/**
|
||||
- web/Dockerfile
|
||||
- packages/**
|
||||
- package.json
|
||||
- pnpm-lock.yaml
|
||||
- pnpm-workspace.yaml
|
||||
- .npmrc
|
||||
- .nvmrc
|
||||
|
||||
concurrency:
|
||||
group: docker-build-${{ github.head_ref || github.run_id }}
|
||||
|
||||
1
.github/workflows/main-ci.yml
vendored
1
.github/workflows/main-ci.yml
vendored
@ -92,6 +92,7 @@ jobs:
|
||||
vdb:
|
||||
- 'api/core/rag/datasource/**'
|
||||
- 'api/tests/integration_tests/vdb/**'
|
||||
- 'api/providers/vdb/*/tests/**'
|
||||
- '.github/workflows/vdb-tests.yml'
|
||||
- '.github/workflows/expose_service_ports.sh'
|
||||
- 'docker/.env.example'
|
||||
|
||||
@ -62,7 +62,7 @@ jobs:
|
||||
- name: Render coverage markdown from structured data
|
||||
id: render
|
||||
run: |
|
||||
comment_body="$(uv run --directory api python api/libs/pyrefly_type_coverage.py \
|
||||
comment_body="$(uv run --directory api python libs/pyrefly_type_coverage.py \
|
||||
--base base_report.json \
|
||||
< pr_report.json)"
|
||||
|
||||
|
||||
6
.github/workflows/stale.yml
vendored
6
.github/workflows/stale.yml
vendored
@ -23,8 +23,8 @@ jobs:
|
||||
days-before-issue-stale: 15
|
||||
days-before-issue-close: 3
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: "Close due to it's no longer active, if you have any questions, you can reopen it."
|
||||
stale-pr-message: "Close due to it's no longer active, if you have any questions, you can reopen it."
|
||||
stale-issue-message: "Closed due to inactivity. If you have any questions, you can reopen it."
|
||||
stale-pr-message: "Closed due to inactivity. If you have any questions, you can reopen it."
|
||||
stale-issue-label: 'no-issue-activity'
|
||||
stale-pr-label: 'no-pr-activity'
|
||||
any-of-labels: 'duplicate,question,invalid,wontfix,no-issue-activity,no-pr-activity,enhancement,cant-reproduce,help-wanted'
|
||||
any-of-labels: '🌚 invalid,🙋♂️ question,wont-fix,no-issue-activity,no-pr-activity,💪 enhancement,🤔 cant-reproduce,🙏 help wanted'
|
||||
|
||||
2
.github/workflows/vdb-tests-full.yml
vendored
2
.github/workflows/vdb-tests-full.yml
vendored
@ -89,7 +89,7 @@ jobs:
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
# - name: Check VDB Ready (TiDB)
|
||||
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: uv run --project api bash dev/pytest/pytest_vdb.sh
|
||||
|
||||
10
.github/workflows/vdb-tests.yml
vendored
10
.github/workflows/vdb-tests.yml
vendored
@ -81,12 +81,12 @@ jobs:
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
# - name: Check VDB Ready (TiDB)
|
||||
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: |
|
||||
uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
|
||||
api/tests/integration_tests/vdb/chroma \
|
||||
api/tests/integration_tests/vdb/pgvector \
|
||||
api/tests/integration_tests/vdb/qdrant \
|
||||
api/tests/integration_tests/vdb/weaviate
|
||||
api/providers/vdb/vdb-chroma/tests/integration_tests \
|
||||
api/providers/vdb/vdb-pgvector/tests/integration_tests \
|
||||
api/providers/vdb/vdb-qdrant/tests/integration_tests \
|
||||
api/providers/vdb/vdb-weaviate/tests/integration_tests
|
||||
|
||||
15
.vscode/launch.json.template
vendored
15
.vscode/launch.json.template
vendored
@ -2,21 +2,10 @@
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python: Flask API",
|
||||
"name": "Python: API (gevent)",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "flask",
|
||||
"env": {
|
||||
"FLASK_APP": "app.py",
|
||||
"FLASK_ENV": "development"
|
||||
},
|
||||
"args": [
|
||||
"run",
|
||||
"--host=0.0.0.0",
|
||||
"--port=5001",
|
||||
"--no-debugger",
|
||||
"--no-reload"
|
||||
],
|
||||
"program": "${workspaceFolder}/api/app.py",
|
||||
"jinja": true,
|
||||
"justMyCode": true,
|
||||
"cwd": "${workspaceFolder}/api",
|
||||
|
||||
@ -33,6 +33,9 @@ TRIGGER_URL=http://localhost:5001
|
||||
# The time in seconds after the signature is rejected
|
||||
FILES_ACCESS_TIMEOUT=300
|
||||
|
||||
# Collaboration mode toggle
|
||||
ENABLE_COLLABORATION_MODE=false
|
||||
|
||||
# Access token expiration time in minutes
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
||||
|
||||
@ -57,6 +60,9 @@ REDIS_SSL_CERTFILE=
|
||||
REDIS_SSL_KEYFILE=
|
||||
# Path to client private key file for SSL authentication
|
||||
REDIS_DB=0
|
||||
# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
|
||||
# Leave empty to preserve current unprefixed behavior.
|
||||
REDIS_KEY_PREFIX=
|
||||
|
||||
# redis Sentinel configuration.
|
||||
REDIS_USE_SENTINEL=false
|
||||
|
||||
@ -69,8 +69,6 @@ ignore = [
|
||||
"FURB152", # math-constant
|
||||
"UP007", # non-pep604-annotation
|
||||
"UP032", # f-string
|
||||
"UP045", # non-pep604-annotation-optional
|
||||
"B005", # strip-with-multi-characters
|
||||
"B006", # mutable-argument-default
|
||||
"B007", # unused-loop-control-variable
|
||||
"B026", # star-arg-unpacking-after-keyword-arg
|
||||
@ -84,7 +82,6 @@ ignore = [
|
||||
"SIM102", # collapsible-if
|
||||
"SIM103", # needless-bool
|
||||
"SIM105", # suppressible-exception
|
||||
"SIM107", # return-in-try-except-finally
|
||||
"SIM108", # if-else-block-instead-of-if-exp
|
||||
"SIM113", # enumerate-for-loop
|
||||
"SIM117", # multiple-with-statements
|
||||
@ -93,29 +90,16 @@ ignore = [
|
||||
]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
"__init__.py" = [
|
||||
"F401", # unused-import
|
||||
"F811", # redefined-while-unused
|
||||
]
|
||||
"configs/*" = [
|
||||
"N802", # invalid-function-name
|
||||
]
|
||||
"graphon/model_runtime/callbacks/base_callback.py" = ["T201"]
|
||||
"core/workflow/callbacks/workflow_logging_callback.py" = ["T201"]
|
||||
"libs/gmpy2_pkcs10aep_cipher.py" = [
|
||||
"N803", # invalid-argument-name
|
||||
]
|
||||
"tests/*" = [
|
||||
"F811", # redefined-while-unused
|
||||
"T201", # allow print in tests,
|
||||
"S110", # allow ignoring exceptions in tests code (currently)
|
||||
|
||||
]
|
||||
"controllers/console/explore/trial.py" = ["TID251"]
|
||||
"controllers/console/human_input_form.py" = ["TID251"]
|
||||
"controllers/web/human_input_form.py" = ["TID251"]
|
||||
|
||||
[lint.flake8-tidy-imports]
|
||||
|
||||
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
|
||||
msg = "Use Pydantic payload/query models instead of reqparse."
|
||||
|
||||
18
api/.vscode/launch.json.example
vendored
18
api/.vscode/launch.json.example
vendored
@ -3,29 +3,21 @@
|
||||
"compounds": [
|
||||
{
|
||||
"name": "Launch Flask and Celery",
|
||||
"configurations": ["Python: Flask", "Python: Celery"]
|
||||
"configurations": ["Python: API (gevent)", "Python: Celery"]
|
||||
}
|
||||
],
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python: Flask",
|
||||
"consoleName": "Flask",
|
||||
"name": "Python: API (gevent)",
|
||||
"consoleName": "API",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"python": "${workspaceFolder}/.venv/bin/python",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"envFile": ".env",
|
||||
"module": "flask",
|
||||
"program": "${workspaceFolder}/app.py",
|
||||
"justMyCode": true,
|
||||
"jinja": true,
|
||||
"env": {
|
||||
"FLASK_APP": "app.py",
|
||||
"GEVENT_SUPPORT": "True"
|
||||
},
|
||||
"args": [
|
||||
"run",
|
||||
"--port=5001"
|
||||
]
|
||||
"jinja": true
|
||||
},
|
||||
{
|
||||
"name": "Python: Celery",
|
||||
|
||||
@ -21,8 +21,9 @@ RUN apt-get update \
|
||||
# for building gmpy2
|
||||
libmpfr-dev libmpc-dev
|
||||
|
||||
# Install Python dependencies
|
||||
# Install Python dependencies (workspace members under providers/vdb/)
|
||||
COPY pyproject.toml uv.lock ./
|
||||
COPY providers ./providers
|
||||
RUN uv sync --locked --no-dev
|
||||
|
||||
# production stage
|
||||
|
||||
29
api/app.py
29
api/app.py
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
@ -9,17 +10,35 @@ if TYPE_CHECKING:
|
||||
celery: Celery
|
||||
|
||||
|
||||
HOST = "0.0.0.0"
|
||||
PORT = 5001
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_db_command() -> bool:
|
||||
if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def log_startup_banner(host: str, port: int) -> None:
|
||||
debugger_attached = sys.gettrace() is not None
|
||||
logger.info("Serving Dify API via gevent WebSocket server")
|
||||
logger.info("Bound to http://%s:%s", host, port)
|
||||
logger.info("Debugger attached: %s", "on" if debugger_attached else "off")
|
||||
logger.info("Press CTRL+C to quit")
|
||||
|
||||
|
||||
# create app
|
||||
flask_app = None
|
||||
socketio_app = None
|
||||
|
||||
if is_db_command():
|
||||
from app_factory import create_migrations_app
|
||||
|
||||
app = create_migrations_app()
|
||||
socketio_app = app
|
||||
flask_app = app
|
||||
else:
|
||||
# Gunicorn and Celery handle monkey patching automatically in production by
|
||||
# specifying the `gevent` worker class. Manual monkey patching is not required here.
|
||||
@ -30,8 +49,14 @@ else:
|
||||
|
||||
from app_factory import create_app
|
||||
|
||||
app = create_app()
|
||||
socketio_app, flask_app = create_app()
|
||||
app = flask_app
|
||||
celery = cast("Celery", app.extensions["celery"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(host="0.0.0.0", port=5001)
|
||||
from gevent import pywsgi
|
||||
from geventwebsocket.handler import WebSocketHandler # type: ignore[reportMissingTypeStubs]
|
||||
|
||||
log_startup_banner(HOST, PORT)
|
||||
server = pywsgi.WSGIServer((HOST, PORT), socketio_app, handler_class=WebSocketHandler)
|
||||
server.serve_forever()
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import socketio # type: ignore[reportMissingTypeStubs]
|
||||
from flask import request
|
||||
from opentelemetry.trace import get_current_span
|
||||
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
|
||||
@ -10,6 +11,7 @@ from contexts.wrapper import RecyclableContextVar
|
||||
from controllers.console.error import UnauthorizedAndForceLogout
|
||||
from core.logging.context import init_request_context
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_socketio import sio
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
@ -122,14 +124,18 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
return dify_app
|
||||
|
||||
|
||||
def create_app() -> DifyApp:
|
||||
def create_app() -> tuple[socketio.WSGIApp, DifyApp]:
|
||||
start_time = time.perf_counter()
|
||||
app = create_flask_app_with_configs()
|
||||
initialize_extensions(app)
|
||||
|
||||
sio.app = app
|
||||
socketio_app = socketio.WSGIApp(sio, app)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
if dify_config.DEBUG:
|
||||
logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
|
||||
return app
|
||||
return socketio_app, app
|
||||
|
||||
|
||||
def initialize_extensions(app: DifyApp):
|
||||
|
||||
@ -341,11 +341,10 @@ def add_qdrant_index(field: str):
|
||||
click.echo(click.style("No dataset collection bindings found.", fg="red"))
|
||||
return
|
||||
import qdrant_client
|
||||
from dify_vdb_qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
from qdrant_client.http.models import PayloadSchemaType
|
||||
|
||||
from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
|
||||
|
||||
for binding in bindings:
|
||||
if dify_config.QDRANT_URL is None:
|
||||
raise ValueError("Qdrant URL is required.")
|
||||
|
||||
@ -1274,6 +1274,13 @@ class PositionConfig(BaseSettings):
|
||||
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
|
||||
|
||||
|
||||
class CollaborationConfig(BaseSettings):
|
||||
ENABLE_COLLABORATION_MODE: bool = Field(
|
||||
description="Whether to enable collaboration mode features across the workspace",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class LoginConfig(BaseSettings):
|
||||
ENABLE_EMAIL_CODE_LOGIN: bool = Field(
|
||||
description="whether to enable email code login",
|
||||
@ -1399,6 +1406,7 @@ class FeatureConfig(
|
||||
WorkflowConfig,
|
||||
WorkflowNodeExecutionConfig,
|
||||
WorkspaceConfig,
|
||||
CollaborationConfig,
|
||||
LoginConfig,
|
||||
AccountConfig,
|
||||
SwaggerUIConfig,
|
||||
|
||||
@ -160,6 +160,16 @@ class DatabaseConfig(BaseSettings):
|
||||
default="",
|
||||
)
|
||||
|
||||
DB_SESSION_TIMEZONE_OVERRIDE: str = Field(
|
||||
description=(
|
||||
"PostgreSQL session timezone override injected via startup options."
|
||||
" Default is 'UTC' for out-of-the-box consistency."
|
||||
" Set to empty string to disable app-level timezone injection, for example when using RDS Proxy"
|
||||
" together with a database-side default timezone."
|
||||
),
|
||||
default="UTC",
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str:
|
||||
@ -227,12 +237,13 @@ class DatabaseConfig(BaseSettings):
|
||||
connect_args: dict[str, str] = {}
|
||||
# Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property
|
||||
if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"):
|
||||
timezone_opt = "-c timezone=UTC"
|
||||
if options:
|
||||
merged_options = f"{options} {timezone_opt}"
|
||||
else:
|
||||
merged_options = timezone_opt
|
||||
connect_args = {"options": merged_options}
|
||||
merged_options = options.strip()
|
||||
session_timezone_override = self.DB_SESSION_TIMEZONE_OVERRIDE.strip()
|
||||
if session_timezone_override:
|
||||
timezone_opt = f"-c timezone={session_timezone_override}"
|
||||
merged_options = f"{merged_options} {timezone_opt}".strip() if merged_options else timezone_opt
|
||||
if merged_options:
|
||||
connect_args = {"options": merged_options}
|
||||
|
||||
result: SQLAlchemyEngineOptionsDict = {
|
||||
"pool_size": self.SQLALCHEMY_POOL_SIZE,
|
||||
|
||||
5
api/configs/middleware/cache/redis_config.py
vendored
5
api/configs/middleware/cache/redis_config.py
vendored
@ -32,6 +32,11 @@ class RedisConfig(BaseSettings):
|
||||
default=0,
|
||||
)
|
||||
|
||||
REDIS_KEY_PREFIX: str = Field(
|
||||
description="Optional global prefix for Redis keys, topics, and transport artifacts",
|
||||
default="",
|
||||
)
|
||||
|
||||
REDIS_USE_SSL: bool = Field(
|
||||
description="Enable SSL/TLS for the Redis connection",
|
||||
default=False,
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
@ -42,17 +41,17 @@ class HologresConfig(BaseSettings):
|
||||
default="public",
|
||||
)
|
||||
|
||||
HOLOGRES_TOKENIZER: TokenizerType = Field(
|
||||
HOLOGRES_TOKENIZER: str = Field(
|
||||
description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').",
|
||||
default="jieba",
|
||||
)
|
||||
|
||||
HOLOGRES_DISTANCE_METHOD: DistanceType = Field(
|
||||
HOLOGRES_DISTANCE_METHOD: str = Field(
|
||||
description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').",
|
||||
default="Cosine",
|
||||
)
|
||||
|
||||
HOLOGRES_BASE_QUANTIZATION_TYPE: BaseQuantizationType = Field(
|
||||
HOLOGRES_BASE_QUANTIZATION_TYPE: str = Field(
|
||||
description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').",
|
||||
default="rabitq",
|
||||
)
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
"""Configuration for InterSystems IRIS vector database."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, PositiveInt, model_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
@ -64,7 +66,7 @@ class IrisVectorConfig(BaseSettings):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Validate IRIS configuration values.
|
||||
|
||||
Args:
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
@ -23,9 +24,9 @@ class ConversationRenamePayload(BaseModel):
|
||||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUIDStrOrEmpty
|
||||
first_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
conversation_id: UUIDStrOrEmpty = Field(description="Conversation UUID")
|
||||
first_id: UUIDStrOrEmpty | None = Field(default=None, description="First message ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
@ -69,11 +70,35 @@ class WorkflowUpdatePayload(BaseModel):
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
# --- Dataset schemas ---
|
||||
|
||||
|
||||
DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
|
||||
|
||||
|
||||
class ChildChunkCreatePayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class ChildChunkUpdatePayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class DocumentBatchDownloadZipPayload(BaseModel):
|
||||
"""Request payload for bulk downloading documents as a zip archive."""
|
||||
|
||||
document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
|
||||
|
||||
|
||||
class MetadataUpdatePayload(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
# --- Audio schemas ---
|
||||
|
||||
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = None
|
||||
message_id: str | None = Field(default=None, description="Message ID")
|
||||
voice: str | None = Field(default=None, description="Voice to use for TTS")
|
||||
text: str | None = Field(default=None, description="Text to convert to audio")
|
||||
streaming: bool | None = Field(default=None, description="Enable streaming response")
|
||||
|
||||
@ -65,6 +65,7 @@ from .app import (
|
||||
statistic,
|
||||
workflow,
|
||||
workflow_app_log,
|
||||
workflow_comment,
|
||||
workflow_draft_variable,
|
||||
workflow_run,
|
||||
workflow_statistic,
|
||||
@ -116,6 +117,7 @@ from .explore import (
|
||||
saved_message,
|
||||
trial,
|
||||
)
|
||||
from .socketio import workflow as socketio_workflow # pyright: ignore[reportUnusedImport]
|
||||
|
||||
# Import tag controllers
|
||||
from .tag import tags
|
||||
@ -201,6 +203,7 @@ __all__ = [
|
||||
"saved_message",
|
||||
"setup",
|
||||
"site",
|
||||
"socketio_workflow",
|
||||
"spec",
|
||||
"statistic",
|
||||
"tags",
|
||||
@ -211,6 +214,7 @@ __all__ = [
|
||||
"website",
|
||||
"workflow",
|
||||
"workflow_app_log",
|
||||
"workflow_comment",
|
||||
"workflow_draft_variable",
|
||||
"workflow_run",
|
||||
"workflow_statistic",
|
||||
|
||||
@ -1,12 +1,16 @@
|
||||
from datetime import datetime
|
||||
|
||||
import flask_restx
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx import Resource
|
||||
from flask_restx._http import HTTPStatus
|
||||
from pydantic import field_validator
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Dataset
|
||||
from models.enums import ApiTokenType
|
||||
@ -16,21 +20,31 @@ from services.api_token_service import ApiTokenCache
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
|
||||
api_key_fields = {
|
||||
"id": fields.String,
|
||||
"type": fields.String,
|
||||
"token": fields.String,
|
||||
"last_used_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields)
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
api_key_list = {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
|
||||
|
||||
api_key_list_model = console_ns.model(
|
||||
"ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
|
||||
)
|
||||
class ApiKeyItem(ResponseModel):
|
||||
id: str
|
||||
type: str
|
||||
token: str
|
||||
last_used_at: int | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("last_used_at", "created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class ApiKeyList(ResponseModel):
|
||||
data: list[ApiKeyItem]
|
||||
|
||||
|
||||
register_schema_models(console_ns, ApiKeyItem, ApiKeyList)
|
||||
|
||||
|
||||
def _get_resource(resource_id, tenant_id, resource_model):
|
||||
@ -54,7 +68,6 @@ class BaseApiKeyListResource(Resource):
|
||||
token_prefix: str | None = None
|
||||
max_keys = 10
|
||||
|
||||
@marshal_with(api_key_list_model)
|
||||
def get(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
@ -66,9 +79,8 @@ class BaseApiKeyListResource(Resource):
|
||||
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
||||
)
|
||||
).all()
|
||||
return {"items": keys}
|
||||
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@marshal_with(api_key_item_model)
|
||||
@edit_permission_required
|
||||
def post(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
@ -100,7 +112,7 @@ class BaseApiKeyListResource(Resource):
|
||||
api_token.type = self.resource_type
|
||||
db.session.add(api_token)
|
||||
db.session.commit()
|
||||
return api_token, 201
|
||||
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201
|
||||
|
||||
|
||||
class BaseApiKeyResource(Resource):
|
||||
@ -147,7 +159,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc("get_app_api_keys")
|
||||
@console_ns.doc(description="Get all API keys for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(200, "Success", api_key_list_model)
|
||||
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||
def get(self, resource_id): # type: ignore
|
||||
"""Get all API keys for an app"""
|
||||
return super().get(resource_id)
|
||||
@ -155,7 +167,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc("create_app_api_key")
|
||||
@console_ns.doc(description="Create a new API key for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(201, "API key created successfully", api_key_item_model)
|
||||
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id): # type: ignore
|
||||
"""Create a new API key for an app"""
|
||||
@ -187,7 +199,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc("get_dataset_api_keys")
|
||||
@console_ns.doc(description="Get all API keys for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Success", api_key_list_model)
|
||||
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||
def get(self, resource_id): # type: ignore
|
||||
"""Get all API keys for a dataset"""
|
||||
return super().get(resource_id)
|
||||
@ -195,7 +207,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc("create_dataset_api_key")
|
||||
@console_ns.doc(description="Create a new API key for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(201, "API key created successfully", api_key_item_model)
|
||||
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id): # type: ignore
|
||||
"""Create a new API key for a dataset"""
|
||||
|
||||
@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateArgs, AdvancedPromptTemplateService
|
||||
|
||||
|
||||
class AdvancedPromptTemplateQuery(BaseModel):
|
||||
@ -35,5 +35,10 @@ class AdvancedPromptTemplateList(Resource):
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
return AdvancedPromptTemplateService.get_prompt(args.model_dump())
|
||||
prompt_args: AdvancedPromptTemplateArgs = {
|
||||
"app_mode": args.app_mode,
|
||||
"model_mode": args.model_mode,
|
||||
"model_name": args.model_name,
|
||||
"has_context": args.has_context,
|
||||
}
|
||||
return AdvancedPromptTemplateService.get_prompt(prompt_args)
|
||||
|
||||
@ -25,7 +25,13 @@ from fields.annotation_fields import (
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.annotation_service import (
|
||||
AppAnnotationService,
|
||||
EnableAnnotationArgs,
|
||||
UpdateAnnotationArgs,
|
||||
UpdateAnnotationSettingArgs,
|
||||
UpsertAnnotationArgs,
|
||||
)
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
@ -120,7 +126,12 @@ class AnnotationReplyActionApi(Resource):
|
||||
args = AnnotationReplyPayload.model_validate(console_ns.payload)
|
||||
match action:
|
||||
case "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
|
||||
enable_args: EnableAnnotationArgs = {
|
||||
"score_threshold": args.score_threshold,
|
||||
"embedding_provider_name": args.embedding_provider_name,
|
||||
"embedding_model_name": args.embedding_model_name,
|
||||
}
|
||||
result = AppAnnotationService.enable_app_annotation(enable_args, app_id)
|
||||
case "disable":
|
||||
result = AppAnnotationService.disable_app_annotation(app_id)
|
||||
return result, 200
|
||||
@ -161,7 +172,8 @@ class AppAnnotationSettingUpdateApi(Resource):
|
||||
|
||||
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
|
||||
setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold}
|
||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, setting_args)
|
||||
return result, 200
|
||||
|
||||
|
||||
@ -237,8 +249,16 @@ class AnnotationApi(Resource):
|
||||
def post(self, app_id):
|
||||
app_id = str(app_id)
|
||||
args = CreateAnnotationPayload.model_validate(console_ns.payload)
|
||||
data = args.model_dump(exclude_none=True)
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
|
||||
upsert_args: UpsertAnnotationArgs = {}
|
||||
if args.answer is not None:
|
||||
upsert_args["answer"] = args.answer
|
||||
if args.content is not None:
|
||||
upsert_args["content"] = args.content
|
||||
if args.message_id is not None:
|
||||
upsert_args["message_id"] = args.message_id
|
||||
if args.question is not None:
|
||||
upsert_args["question"] = args.question
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, app_id)
|
||||
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@setup_required
|
||||
@ -315,9 +335,12 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
args = UpdateAnnotationPayload.model_validate(console_ns.payload)
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(
|
||||
args.model_dump(exclude_none=True), app_id, annotation_id
|
||||
)
|
||||
update_args: UpdateAnnotationArgs = {}
|
||||
if args.answer is not None:
|
||||
update_args["answer"] = args.answer
|
||||
if args.question is not None:
|
||||
update_args["question"] = args.question
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_id, annotation_id)
|
||||
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@setup_required
|
||||
|
||||
@ -6,10 +6,9 @@ from typing import Any, Literal
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.file import helpers as file_helpers
|
||||
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.common.helpers import FileInfo
|
||||
@ -31,13 +30,14 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.trigger.constants import TRIGGER_NODE_TYPES
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import build_icon_url
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, DatasetPermissionEnum, Workflow
|
||||
from models.model import IconType
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.entities.dsl_entities import ImportMode
|
||||
from services.entities.dsl_entities import ImportMode, ImportStatus
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DataSource,
|
||||
InfoList,
|
||||
@ -161,15 +161,6 @@ def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
return value
|
||||
|
||||
|
||||
def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
|
||||
if icon is None or icon_type is None:
|
||||
return None
|
||||
icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
|
||||
if icon_type_value.lower() != IconType.IMAGE:
|
||||
return None
|
||||
return file_helpers.get_signed_file_url(icon)
|
||||
|
||||
|
||||
class Tag(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
@ -292,7 +283,7 @@ class Site(ResponseModel):
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return _build_icon_url(self.icon_type, self.icon)
|
||||
return build_icon_url(self.icon_type, self.icon)
|
||||
|
||||
@field_validator("icon_type", mode="before")
|
||||
@classmethod
|
||||
@ -342,7 +333,7 @@ class AppPartial(ResponseModel):
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return _build_icon_url(self.icon_type, self.icon)
|
||||
return build_icon_url(self.icon_type, self.icon)
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
@ -390,7 +381,7 @@ class AppDetailWithSite(AppDetail):
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return _build_icon_url(self.icon_type, self.icon)
|
||||
return build_icon_url(self.icon_type, self.icon)
|
||||
|
||||
|
||||
class AppPagination(ResponseModel):
|
||||
@ -632,7 +623,7 @@ class AppCopyApi(Resource):
|
||||
|
||||
args = CopyAppPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = AppDslService(session)
|
||||
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
|
||||
result = import_service.import_app(
|
||||
@ -645,6 +636,13 @@ class AppCopyApi(Resource):
|
||||
icon=args.icon,
|
||||
icon_background=args.icon_background,
|
||||
)
|
||||
if result.status == ImportStatus.FAILED:
|
||||
session.rollback()
|
||||
return result.model_dump(mode="json"), 400
|
||||
if result.status == ImportStatus.PENDING:
|
||||
session.rollback()
|
||||
return result.model_dump(mode="json"), 202
|
||||
session.commit()
|
||||
|
||||
# Inherit web app permission from original app
|
||||
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@ -10,35 +11,15 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import (
|
||||
app_import_check_dependencies_fields,
|
||||
app_import_fields,
|
||||
leaked_dependency_fields,
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import App
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_dsl_service import AppDslService, Import
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.entities.dsl_entities import ImportStatus
|
||||
from services.entities.dsl_entities import CheckDependenciesResult, ImportStatus
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register base model first
|
||||
leaked_dependency_model = console_ns.model("LeakedDependency", leaked_dependency_fields)
|
||||
|
||||
app_import_model = console_ns.model("AppImport", app_import_fields)
|
||||
|
||||
# For nested models, need to replace nested dict with registered model
|
||||
app_import_check_dependencies_fields_copy = app_import_check_dependencies_fields.copy()
|
||||
app_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(fields.Nested(leaked_dependency_model))
|
||||
app_import_check_dependencies_model = console_ns.model(
|
||||
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy
|
||||
)
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AppImportPayload(BaseModel):
|
||||
mode: str = Field(..., description="Import mode")
|
||||
@ -52,18 +33,18 @@ class AppImportPayload(BaseModel):
|
||||
app_id: str | None = Field(None)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
register_schema_models(console_ns, AppImportPayload, Import, CheckDependenciesResult)
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports")
|
||||
class AppImportApi(Resource):
|
||||
@console_ns.expect(console_ns.models[AppImportPayload.__name__])
|
||||
@console_ns.response(200, "Import completed", console_ns.models[Import.__name__])
|
||||
@console_ns.response(202, "Import pending confirmation", console_ns.models[Import.__name__])
|
||||
@console_ns.response(400, "Import failed", console_ns.models[Import.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_model)
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
@ -71,8 +52,9 @@ class AppImportApi(Resource):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = AppImportPayload.model_validate(console_ns.payload)
|
||||
|
||||
# Create service with session
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
# AppDslService performs internal commits for some creation paths, so use a plain
|
||||
# Session here instead of nesting it inside sessionmaker(...).begin().
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Import app
|
||||
account = current_user
|
||||
@ -88,6 +70,10 @@ class AppImportApi(Resource):
|
||||
icon_background=args.icon_background,
|
||||
app_id=args.app_id,
|
||||
)
|
||||
if result.status == ImportStatus.FAILED:
|
||||
session.rollback()
|
||||
else:
|
||||
session.commit()
|
||||
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
||||
# update web app setting as private
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
|
||||
@ -104,21 +90,25 @@ class AppImportApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/imports/<string:import_id>/confirm")
|
||||
class AppImportConfirmApi(Resource):
|
||||
@console_ns.response(200, "Import confirmed", console_ns.models[Import.__name__])
|
||||
@console_ns.response(400, "Import failed", console_ns.models[Import.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_model)
|
||||
@edit_permission_required
|
||||
def post(self, import_id):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# Create service with session
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Confirm import
|
||||
account = current_user
|
||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||
if result.status == ImportStatus.FAILED:
|
||||
session.rollback()
|
||||
else:
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED:
|
||||
@ -128,14 +118,14 @@ class AppImportConfirmApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/imports/<string:app_id>/check-dependencies")
|
||||
class AppImportCheckDependenciesApi(Resource):
|
||||
@console_ns.response(200, "Dependencies checked", console_ns.models[CheckDependenciesResult.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_app_model
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_check_dependencies_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = AppDslService(session)
|
||||
result = import_service.check_dependencies(app_model=app_model)
|
||||
|
||||
|
||||
@ -2,20 +2,37 @@ from typing import Literal
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import selectinload
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.raws import FilesContainedField
|
||||
from fields.conversation_fields import (
|
||||
Conversation as ConversationResponse,
|
||||
)
|
||||
from fields.conversation_fields import (
|
||||
ConversationDetail as ConversationDetailResponse,
|
||||
)
|
||||
from fields.conversation_fields import (
|
||||
ConversationMessageDetail as ConversationMessageDetailResponse,
|
||||
)
|
||||
from fields.conversation_fields import (
|
||||
ConversationPagination as ConversationPaginationResponse,
|
||||
)
|
||||
from fields.conversation_fields import (
|
||||
ConversationWithSummaryPagination as ConversationWithSummaryPaginationResponse,
|
||||
)
|
||||
from fields.conversation_fields import (
|
||||
ResultResponse,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||
from models.model import AppMode
|
||||
@ -62,267 +79,16 @@ console_ns.schema_model(
|
||||
ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
||||
# Base models
|
||||
simple_account_model = console_ns.model(
|
||||
"SimpleAccount",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"email": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
feedback_stat_model = console_ns.model(
|
||||
"FeedbackStat",
|
||||
{
|
||||
"like": fields.Integer,
|
||||
"dislike": fields.Integer,
|
||||
},
|
||||
)
|
||||
|
||||
status_count_model = console_ns.model(
|
||||
"StatusCount",
|
||||
{
|
||||
"success": fields.Integer,
|
||||
"failed": fields.Integer,
|
||||
"partial_success": fields.Integer,
|
||||
"paused": fields.Integer,
|
||||
},
|
||||
)
|
||||
|
||||
message_file_model = console_ns.model(
|
||||
"MessageFile",
|
||||
{
|
||||
"id": fields.String,
|
||||
"filename": fields.String,
|
||||
"type": fields.String,
|
||||
"url": fields.String,
|
||||
"mime_type": fields.String,
|
||||
"size": fields.Integer,
|
||||
"transfer_method": fields.String,
|
||||
"belongs_to": fields.String(default="user"),
|
||||
"upload_file_id": fields.String(default=None),
|
||||
},
|
||||
)
|
||||
|
||||
agent_thought_model = console_ns.model(
|
||||
"AgentThought",
|
||||
{
|
||||
"id": fields.String,
|
||||
"chain_id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"thought": fields.String,
|
||||
"tool": fields.String,
|
||||
"tool_labels": fields.Raw,
|
||||
"tool_input": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"observation": fields.String,
|
||||
"files": fields.List(fields.String),
|
||||
},
|
||||
)
|
||||
|
||||
simple_model_config_model = console_ns.model(
|
||||
"SimpleModelConfig",
|
||||
{
|
||||
"model": fields.Raw(attribute="model_dict"),
|
||||
"pre_prompt": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
model_config_model = console_ns.model(
|
||||
"ModelConfig",
|
||||
{
|
||||
"opening_statement": fields.String,
|
||||
"suggested_questions": fields.Raw,
|
||||
"model": fields.Raw,
|
||||
"user_input_form": fields.Raw,
|
||||
"pre_prompt": fields.String,
|
||||
"agent_mode": fields.Raw,
|
||||
},
|
||||
)
|
||||
|
||||
# Models that depend on simple_account_model
|
||||
feedback_model = console_ns.model(
|
||||
"Feedback",
|
||||
{
|
||||
"rating": fields.String,
|
||||
"content": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account": fields.Nested(simple_account_model, allow_null=True),
|
||||
},
|
||||
)
|
||||
|
||||
annotation_model = console_ns.model(
|
||||
"Annotation",
|
||||
{
|
||||
"id": fields.String,
|
||||
"question": fields.String,
|
||||
"content": fields.String,
|
||||
"account": fields.Nested(simple_account_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
},
|
||||
)
|
||||
|
||||
annotation_hit_history_model = console_ns.model(
|
||||
"AnnotationHitHistory",
|
||||
{
|
||||
"annotation_id": fields.String(attribute="id"),
|
||||
"annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class MessageTextField(fields.Raw):
|
||||
def format(self, value):
|
||||
return value[0]["text"] if value else ""
|
||||
|
||||
|
||||
# Simple message detail model
|
||||
simple_message_detail_model = console_ns.model(
|
||||
"SimpleMessageDetail",
|
||||
{
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"message": MessageTextField,
|
||||
"answer": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
# Message detail model that depends on multiple models
|
||||
message_detail_model = console_ns.model(
|
||||
"MessageDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"message": fields.Raw,
|
||||
"message_tokens": fields.Integer,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"answer_tokens": fields.Integer,
|
||||
"provider_response_latency": fields.Float,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"feedbacks": fields.List(fields.Nested(feedback_model)),
|
||||
"workflow_run_id": fields.String,
|
||||
"annotation": fields.Nested(annotation_model, allow_null=True),
|
||||
"annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
|
||||
"message_files": fields.List(fields.Nested(message_file_model)),
|
||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
# Conversation models
|
||||
conversation_fields_model = console_ns.model(
|
||||
"Conversation",
|
||||
{
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_end_user_session_id": fields.String(),
|
||||
"from_account_id": fields.String,
|
||||
"from_account_name": fields.String,
|
||||
"read_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotation": fields.Nested(annotation_model, allow_null=True),
|
||||
"model_config": fields.Nested(simple_model_config_model),
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"message": fields.Nested(simple_message_detail_model, attribute="first_message"),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_pagination_model = console_ns.model(
|
||||
"ConversationPagination",
|
||||
{
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(conversation_fields_model), attribute="items"),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_message_detail_model = console_ns.model(
|
||||
"ConversationMessageDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"model_config": fields.Nested(model_config_model),
|
||||
"message": fields.Nested(message_detail_model, attribute="first_message"),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_with_summary_model = console_ns.model(
|
||||
"ConversationWithSummary",
|
||||
{
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_end_user_session_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"from_account_name": fields.String,
|
||||
"name": fields.String,
|
||||
"summary": fields.String(attribute="summary_or_query"),
|
||||
"read_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotated": fields.Boolean,
|
||||
"model_config": fields.Nested(simple_model_config_model),
|
||||
"message_count": fields.Integer,
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"status_count": fields.Nested(status_count_model),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_with_summary_pagination_model = console_ns.model(
|
||||
"ConversationWithSummaryPagination",
|
||||
{
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(conversation_with_summary_model), attribute="items"),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_detail_model = console_ns.model(
|
||||
"ConversationDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotated": fields.Boolean,
|
||||
"introduction": fields.String,
|
||||
"model_config": fields.Nested(model_config_model),
|
||||
"message_count": fields.Integer,
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
},
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
CompletionConversationQuery,
|
||||
ChatConversationQuery,
|
||||
ConversationResponse,
|
||||
ConversationPaginationResponse,
|
||||
ConversationMessageDetailResponse,
|
||||
ConversationWithSummaryPaginationResponse,
|
||||
ConversationDetailResponse,
|
||||
ResultResponse,
|
||||
)
|
||||
|
||||
|
||||
@ -332,13 +98,12 @@ class CompletionConversationApi(Resource):
|
||||
@console_ns.doc(description="Get completion conversations with pagination and filtering")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[CompletionConversationQuery.__name__])
|
||||
@console_ns.response(200, "Success", conversation_pagination_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[ConversationPaginationResponse.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@marshal_with(conversation_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@ -394,7 +159,9 @@ class CompletionConversationApi(Resource):
|
||||
|
||||
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
|
||||
|
||||
return conversations
|
||||
return ConversationPaginationResponse.model_validate(conversations, from_attributes=True).model_dump(
|
||||
mode="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>")
|
||||
@ -402,19 +169,19 @@ class CompletionConversationDetailApi(Resource):
|
||||
@console_ns.doc("get_completion_conversation")
|
||||
@console_ns.doc(description="Get completion conversation details with messages")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@console_ns.response(200, "Success", conversation_message_detail_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[ConversationMessageDetailResponse.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@marshal_with(conversation_message_detail_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model, conversation_id):
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
return _get_conversation(app_model, conversation_id)
|
||||
return ConversationMessageDetailResponse.model_validate(
|
||||
_get_conversation(app_model, conversation_id), from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("delete_completion_conversation")
|
||||
@console_ns.doc(description="Delete a completion conversation")
|
||||
@ -436,7 +203,7 @@ class CompletionConversationDetailApi(Resource):
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-conversations")
|
||||
@ -445,13 +212,12 @@ class ChatConversationApi(Resource):
|
||||
@console_ns.doc(description="Get chat conversations with pagination, filtering and summary")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[ChatConversationQuery.__name__])
|
||||
@console_ns.response(200, "Success", conversation_with_summary_pagination_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[ConversationWithSummaryPaginationResponse.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(conversation_with_summary_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@ -546,7 +312,9 @@ class ChatConversationApi(Resource):
|
||||
|
||||
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
|
||||
|
||||
return conversations
|
||||
return ConversationWithSummaryPaginationResponse.model_validate(conversations, from_attributes=True).model_dump(
|
||||
mode="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>")
|
||||
@ -554,19 +322,19 @@ class ChatConversationDetailApi(Resource):
|
||||
@console_ns.doc("get_chat_conversation")
|
||||
@console_ns.doc(description="Get chat conversation details")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@console_ns.response(200, "Success", conversation_detail_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[ConversationDetailResponse.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(conversation_detail_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model, conversation_id):
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
return _get_conversation(app_model, conversation_id)
|
||||
return ConversationDetailResponse.model_validate(
|
||||
_get_conversation(app_model, conversation_id), from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("delete_chat_conversation")
|
||||
@console_ns.doc(description="Delete a chat conversation")
|
||||
@ -588,7 +356,7 @@ class ChatConversationDetailApi(Resource):
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
|
||||
|
||||
def _get_conversation(app_model, conversation_id):
|
||||
|
||||
@ -1,44 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_variable_fields import (
|
||||
conversation_variable_fields,
|
||||
paginated_conversation_variable_fields,
|
||||
)
|
||||
from fields._value_type_serializer import serialize_value_type
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import login_required
|
||||
from models import ConversationVariable
|
||||
from models.model import AppMode
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ConversationVariablesQuery(BaseModel):
|
||||
conversation_id: str = Field(..., description="Conversation ID to filter variables")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ConversationVariablesQuery.__name__,
|
||||
ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register base model first
|
||||
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
|
||||
|
||||
# For nested models, need to replace nested dict with registered model
|
||||
paginated_conversation_variable_fields_copy = paginated_conversation_variable_fields.copy()
|
||||
paginated_conversation_variable_fields_copy["data"] = fields.List(
|
||||
fields.Nested(conversation_variable_model), attribute="data"
|
||||
)
|
||||
paginated_conversation_variable_model = console_ns.model(
|
||||
"PaginatedConversationVariable", paginated_conversation_variable_fields_copy
|
||||
class ConversationVariableResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
value_type: str
|
||||
value: str | None = None
|
||||
description: str | None = None
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
|
||||
@field_validator("value_type", mode="before")
|
||||
@classmethod
|
||||
def _normalize_value_type(cls, value: Any) -> str:
|
||||
exposed_type = getattr(value, "exposed_type", None)
|
||||
if callable(exposed_type):
|
||||
return str(exposed_type().value)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
return serialize_value_type(value)
|
||||
except Exception:
|
||||
return serialize_value_type({"value_type": value})
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def _normalize_value(cls, value: Any | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class PaginatedConversationVariableResponse(ResponseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[ConversationVariableResponse]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
ConversationVariablesQuery,
|
||||
ConversationVariableResponse,
|
||||
PaginatedConversationVariableResponse,
|
||||
)
|
||||
|
||||
|
||||
@ -48,12 +90,15 @@ class ConversationVariablesApi(Resource):
|
||||
@console_ns.doc(description="Get conversation variables for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
|
||||
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Conversation variables retrieved successfully",
|
||||
console_ns.models[PaginatedConversationVariableResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
@marshal_with(paginated_conversation_variable_model)
|
||||
def get(self, app_model):
|
||||
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
@ -72,17 +117,22 @@ class ConversationVariablesApi(Resource):
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
rows = session.scalars(stmt).all()
|
||||
|
||||
return {
|
||||
"page": page,
|
||||
"limit": page_size,
|
||||
"total": len(rows),
|
||||
"has_more": False,
|
||||
"data": [
|
||||
{
|
||||
"created_at": row.created_at,
|
||||
"updated_at": row.updated_at,
|
||||
**row.to_variable().model_dump(),
|
||||
}
|
||||
for row in rows
|
||||
],
|
||||
}
|
||||
response = PaginatedConversationVariableResponse.model_validate(
|
||||
{
|
||||
"page": page,
|
||||
"limit": page_size,
|
||||
"total": len(rows),
|
||||
"has_more": False,
|
||||
"data": [
|
||||
ConversationVariableResponse.model_validate(
|
||||
{
|
||||
"created_at": row.created_at,
|
||||
"updated_at": row.updated_at,
|
||||
**row.to_variable().model_dump(),
|
||||
}
|
||||
)
|
||||
for row in rows
|
||||
],
|
||||
}
|
||||
)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
@ -1,39 +1,68 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_server_fields
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import AppMCPServerStatus
|
||||
from models.model import AppMCPServer
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
app_server_model = console_ns.model("AppServer", app_server_fields)
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class MCPServerCreatePayload(BaseModel):
|
||||
description: str | None = Field(default=None, description="Server description")
|
||||
parameters: dict = Field(..., description="Server parameters configuration")
|
||||
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
|
||||
|
||||
|
||||
class MCPServerUpdatePayload(BaseModel):
|
||||
id: str = Field(..., description="Server ID")
|
||||
description: str | None = Field(default=None, description="Server description")
|
||||
parameters: dict = Field(..., description="Server parameters configuration")
|
||||
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
|
||||
status: str | None = Field(default=None, description="Server status")
|
||||
|
||||
|
||||
for model in (MCPServerCreatePayload, MCPServerUpdatePayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
class AppMCPServerResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
server_code: str
|
||||
description: str
|
||||
status: str
|
||||
parameters: dict[str, Any] | list[Any] | str
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
|
||||
@field_validator("parameters", mode="before")
|
||||
@classmethod
|
||||
def _parse_json_string(cls, value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return value
|
||||
return value
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
register_schema_models(console_ns, MCPServerCreatePayload, MCPServerUpdatePayload, AppMCPServerResponse)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/server")
|
||||
@ -41,27 +70,27 @@ class AppMCPServerController(Resource):
|
||||
@console_ns.doc("get_app_mcp_server")
|
||||
@console_ns.doc(description="Get MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "MCP server configuration retrieved successfully", app_server_model)
|
||||
@console_ns.response(200, "Server configuration", console_ns.models[AppMCPServerResponse.__name__])
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_model)
|
||||
def get(self, app_model):
|
||||
server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
|
||||
return server
|
||||
if server is None:
|
||||
return {}
|
||||
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("create_app_mcp_server")
|
||||
@console_ns.doc(description="Create MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
|
||||
@console_ns.response(201, "MCP server configuration created successfully", app_server_model)
|
||||
@console_ns.response(200, "Server created", console_ns.models[AppMCPServerResponse.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@login_required
|
||||
@setup_required
|
||||
@marshal_with(app_server_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@ -82,20 +111,19 @@ class AppMCPServerController(Resource):
|
||||
)
|
||||
db.session.add(server)
|
||||
db.session.commit()
|
||||
return server
|
||||
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("update_app_mcp_server")
|
||||
@console_ns.doc(description="Update MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
|
||||
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
|
||||
@console_ns.response(200, "Server updated", console_ns.models[AppMCPServerResponse.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Server not found")
|
||||
@get_app_model
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_server_model)
|
||||
@edit_permission_required
|
||||
def put(self, app_model):
|
||||
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
|
||||
@ -118,7 +146,7 @@ class AppMCPServerController(Resource):
|
||||
except ValueError:
|
||||
raise ValueError("Invalid status")
|
||||
db.session.commit()
|
||||
return server
|
||||
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:server_id>/server/refresh")
|
||||
@ -126,13 +154,12 @@ class AppMCPServerRefreshController(Resource):
|
||||
@console_ns.doc("refresh_app_mcp_server")
|
||||
@console_ns.doc(description="Refresh MCP server configuration and regenerate server code")
|
||||
@console_ns.doc(params={"server_id": "Server ID"})
|
||||
@console_ns.response(200, "MCP server refreshed successfully", app_server_model)
|
||||
@console_ns.response(200, "Server refreshed", console_ns.models[AppMCPServerResponse.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Server not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_server_model)
|
||||
@edit_permission_required
|
||||
def get(self, server_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@ -145,4 +172,4 @@ class AppMCPServerRefreshController(Resource):
|
||||
raise NotFound()
|
||||
server.server_code = AppMCPServer.generate_server_code(16)
|
||||
db.session.commit()
|
||||
return server
|
||||
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx import Resource
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import exists, func, select
|
||||
@ -25,10 +26,21 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from extensions.ext_database import db
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from fields.base import ResponseModel
|
||||
from fields.conversation_fields import (
|
||||
AgentThought,
|
||||
ConversationAnnotation,
|
||||
ConversationAnnotationHitHistory,
|
||||
Feedback,
|
||||
JSONValue,
|
||||
MessageFile,
|
||||
format_files_contained,
|
||||
to_timestamp,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
@ -98,6 +110,51 @@ class SuggestedQuestionsResponse(BaseModel):
|
||||
data: list[str] = Field(description="Suggested question")
|
||||
|
||||
|
||||
class MessageDetailResponse(ResponseModel):
|
||||
id: str
|
||||
conversation_id: str
|
||||
inputs: dict[str, JSONValue]
|
||||
query: str
|
||||
message: JSONValue | None = None
|
||||
message_tokens: int | None = None
|
||||
answer: str = Field(validation_alias="re_sign_file_url_answer")
|
||||
answer_tokens: int | None = None
|
||||
provider_response_latency: float | None = None
|
||||
from_source: str
|
||||
from_end_user_id: str | None = None
|
||||
from_account_id: str | None = None
|
||||
feedbacks: list[Feedback] = Field(default_factory=list)
|
||||
workflow_run_id: str | None = None
|
||||
annotation: ConversationAnnotation | None = None
|
||||
annotation_hit_history: ConversationAnnotationHitHistory | None = None
|
||||
created_at: int | None = None
|
||||
agent_thoughts: list[AgentThought] = Field(default_factory=list)
|
||||
message_files: list[MessageFile] = Field(default_factory=list)
|
||||
extra_contents: list[ExecutionExtraContentDomainModel] = Field(default_factory=list)
|
||||
metadata: JSONValue | None = Field(default=None, validation_alias="message_metadata_dict")
|
||||
status: str
|
||||
error: str | None = None
|
||||
parent_message_id: str | None = None
|
||||
|
||||
@field_validator("inputs", mode="before")
|
||||
@classmethod
|
||||
def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
|
||||
return format_files_contained(value)
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
|
||||
class MessageInfiniteScrollPaginationResponse(ResponseModel):
|
||||
limit: int
|
||||
has_more: bool
|
||||
data: list[MessageDetailResponse]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
ChatMessagesQuery,
|
||||
@ -105,124 +162,8 @@ register_schema_models(
|
||||
FeedbackExportQuery,
|
||||
AnnotationCountResponse,
|
||||
SuggestedQuestionsResponse,
|
||||
)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
||||
# Base models
|
||||
simple_account_model = console_ns.model(
|
||||
"SimpleAccount",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"email": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
message_file_model = console_ns.model(
|
||||
"MessageFile",
|
||||
{
|
||||
"id": fields.String,
|
||||
"filename": fields.String,
|
||||
"type": fields.String,
|
||||
"url": fields.String,
|
||||
"mime_type": fields.String,
|
||||
"size": fields.Integer,
|
||||
"transfer_method": fields.String,
|
||||
"belongs_to": fields.String(default="user"),
|
||||
"upload_file_id": fields.String(default=None),
|
||||
},
|
||||
)
|
||||
|
||||
agent_thought_model = console_ns.model(
|
||||
"AgentThought",
|
||||
{
|
||||
"id": fields.String,
|
||||
"chain_id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"thought": fields.String,
|
||||
"tool": fields.String,
|
||||
"tool_labels": fields.Raw,
|
||||
"tool_input": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"observation": fields.String,
|
||||
"files": fields.List(fields.String),
|
||||
},
|
||||
)
|
||||
|
||||
# Models that depend on simple_account_model
|
||||
feedback_model = console_ns.model(
|
||||
"Feedback",
|
||||
{
|
||||
"rating": fields.String,
|
||||
"content": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account": fields.Nested(simple_account_model, allow_null=True),
|
||||
},
|
||||
)
|
||||
|
||||
annotation_model = console_ns.model(
|
||||
"Annotation",
|
||||
{
|
||||
"id": fields.String,
|
||||
"question": fields.String,
|
||||
"content": fields.String,
|
||||
"account": fields.Nested(simple_account_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
},
|
||||
)
|
||||
|
||||
annotation_hit_history_model = console_ns.model(
|
||||
"AnnotationHitHistory",
|
||||
{
|
||||
"annotation_id": fields.String(attribute="id"),
|
||||
"annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
},
|
||||
)
|
||||
|
||||
# Message detail model that depends on multiple models
|
||||
message_detail_model = console_ns.model(
|
||||
"MessageDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"message": fields.Raw,
|
||||
"message_tokens": fields.Integer,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"answer_tokens": fields.Integer,
|
||||
"provider_response_latency": fields.Float,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"feedbacks": fields.List(fields.Nested(feedback_model)),
|
||||
"workflow_run_id": fields.String,
|
||||
"annotation": fields.Nested(annotation_model, allow_null=True),
|
||||
"annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
|
||||
"message_files": fields.List(fields.Nested(message_file_model)),
|
||||
"extra_contents": fields.List(fields.Raw),
|
||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
# Message infinite scroll pagination model
|
||||
message_infinite_scroll_pagination_model = console_ns.model(
|
||||
"MessageInfiniteScrollPagination",
|
||||
{
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_detail_model)),
|
||||
},
|
||||
MessageDetailResponse,
|
||||
MessageInfiniteScrollPaginationResponse,
|
||||
)
|
||||
|
||||
|
||||
@ -232,13 +173,12 @@ class ChatMessageListApi(Resource):
|
||||
@console_ns.doc(description="Get chat messages for a conversation with pagination")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
|
||||
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[MessageInfiniteScrollPaginationResponse.__name__])
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(message_infinite_scroll_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict())
|
||||
@ -298,7 +238,10 @@ class ChatMessageListApi(Resource):
|
||||
history_messages = list(reversed(history_messages))
|
||||
attach_message_extra_contents(history_messages)
|
||||
|
||||
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
|
||||
return MessageInfiniteScrollPaginationResponse.model_validate(
|
||||
InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more),
|
||||
from_attributes=True,
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
|
||||
@ -468,13 +411,12 @@ class MessageApi(Resource):
|
||||
@console_ns.doc("get_message")
|
||||
@console_ns.doc(description="Get message details by ID")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
|
||||
@console_ns.response(200, "Message retrieved successfully", message_detail_model)
|
||||
@console_ns.response(200, "Message retrieved successfully", console_ns.models[MessageDetailResponse.__name__])
|
||||
@console_ns.response(404, "Message not found")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(message_detail_model)
|
||||
def get(self, app_model, message_id: str):
|
||||
message_id = str(message_id)
|
||||
|
||||
@ -486,4 +428,4 @@ class MessageApi(Resource):
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
attach_message_extra_contents([message])
|
||||
return message
|
||||
return MessageDetailResponse.model_validate(message, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
@ -15,13 +16,11 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_site_fields
|
||||
from fields.base import ResponseModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Site
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AppSiteUpdatePayload(BaseModel):
|
||||
title: str | None = Field(default=None)
|
||||
@ -49,13 +48,26 @@ class AppSiteUpdatePayload(BaseModel):
|
||||
return supported_language(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AppSiteUpdatePayload.__name__,
|
||||
AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
class AppSiteResponse(ResponseModel):
|
||||
app_id: str
|
||||
access_token: str | None = Field(default=None, validation_alias="code")
|
||||
code: str | None = None
|
||||
title: str
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
description: str | None = None
|
||||
default_language: str
|
||||
customize_domain: str | None = None
|
||||
copyright: str | None = None
|
||||
privacy_policy: str | None = None
|
||||
custom_disclaimer: str | None = None
|
||||
customize_token_strategy: str
|
||||
prompt_public: bool
|
||||
show_workflow_steps: bool
|
||||
use_icon_as_answer_icon: bool
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
app_site_model = console_ns.model("AppSite", app_site_fields)
|
||||
|
||||
register_schema_models(console_ns, AppSiteUpdatePayload, AppSiteResponse)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/site")
|
||||
@ -64,7 +76,7 @@ class AppSite(Resource):
|
||||
@console_ns.doc(description="Update application site configuration")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Site configuration updated successfully", app_site_model)
|
||||
@console_ns.response(200, "Site configuration updated successfully", console_ns.models[AppSiteResponse.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "App not found")
|
||||
@setup_required
|
||||
@ -72,7 +84,6 @@ class AppSite(Resource):
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_site_model)
|
||||
def post(self, app_model):
|
||||
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@ -106,7 +117,7 @@ class AppSite(Resource):
|
||||
site.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
return site
|
||||
return AppSiteResponse.model_validate(site, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/site/access-token-reset")
|
||||
@ -114,7 +125,7 @@ class AppSiteAccessTokenReset(Resource):
|
||||
@console_ns.doc("reset_app_site_access_token")
|
||||
@console_ns.doc(description="Reset access token for application site")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Access token reset successfully", app_site_model)
|
||||
@console_ns.response(200, "Access token reset successfully", console_ns.models[AppSiteResponse.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions (admin/owner required)")
|
||||
@console_ns.response(404, "App or site not found")
|
||||
@setup_required
|
||||
@ -122,7 +133,6 @@ class AppSiteAccessTokenReset(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_site_model)
|
||||
def post(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
@ -135,4 +145,4 @@ class AppSiteAccessTokenReset(Resource):
|
||||
site.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
return site
|
||||
return AppSiteResponse.model_validate(site, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@ -4,9 +4,10 @@ from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from graphon.enums import NodeType
|
||||
from graphon.file import File
|
||||
from graphon.file import helpers as file_helpers
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
@ -39,6 +40,7 @@ from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from factories import file_factory, variable_factory
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.online_user_fields import online_user_list_fields
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@ -47,6 +49,7 @@ from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from models.workflow import Workflow
|
||||
from repositories.workflow_collaboration_repository import WORKFLOW_ONLINE_USERS_PREFIX
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
@ -57,6 +60,7 @@ _file_access_controller = DatabaseFileAccessController()
|
||||
LISTENING_RETRY_IN = 2000
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published"
|
||||
MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS = 50
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
@ -150,6 +154,14 @@ class ConvertToWorkflowPayload(BaseModel):
|
||||
icon_background: str | None = None
|
||||
|
||||
|
||||
class WorkflowFeaturesPayload(BaseModel):
|
||||
features: dict[str, Any] = Field(..., description="Workflow feature configuration")
|
||||
|
||||
|
||||
class WorkflowOnlineUsersQuery(BaseModel):
|
||||
app_ids: str = Field(..., description="Comma-separated app IDs")
|
||||
|
||||
|
||||
class DraftWorkflowTriggerRunPayload(BaseModel):
|
||||
node_id: str
|
||||
|
||||
@ -173,6 +185,8 @@ reg(DefaultBlockConfigQuery)
|
||||
reg(ConvertToWorkflowPayload)
|
||||
reg(WorkflowListQuery)
|
||||
reg(WorkflowUpdatePayload)
|
||||
reg(WorkflowFeaturesPayload)
|
||||
reg(WorkflowOnlineUsersQuery)
|
||||
reg(DraftWorkflowTriggerRunPayload)
|
||||
reg(DraftWorkflowTriggerRunAllPayload)
|
||||
|
||||
@ -931,6 +945,32 @@ class ConvertToWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/features")
|
||||
class WorkflowFeaturesApi(Resource):
|
||||
"""Update draft workflow features."""
|
||||
|
||||
@console_ns.expect(console_ns.models[WorkflowFeaturesPayload.__name__])
|
||||
@console_ns.doc("update_workflow_features")
|
||||
@console_ns.doc(description="Update draft workflow features")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Workflow features updated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowFeaturesPayload.model_validate(console_ns.payload or {})
|
||||
features = args.features
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflow_service.update_draft_workflow_features(app_model=app_model, features=features, account=current_user)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
||||
class PublishedAllWorkflowApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
|
||||
@ -942,7 +982,6 @@ class PublishedAllWorkflowApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@ -970,9 +1009,10 @@ class PublishedAllWorkflowApi(Resource):
|
||||
user_id=user_id,
|
||||
named_only=named_only,
|
||||
)
|
||||
serialized_workflows = marshal(workflows, workflow_fields_copy)
|
||||
|
||||
return {
|
||||
"items": workflows,
|
||||
"items": serialized_workflows,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"has_more": has_more,
|
||||
@ -1340,3 +1380,62 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
||||
"status": "error",
|
||||
}
|
||||
), 400
|
||||
|
||||
|
||||
@console_ns.route("/apps/workflows/online-users")
|
||||
class WorkflowOnlineUsersApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowOnlineUsersQuery.__name__])
|
||||
@console_ns.doc("get_workflow_online_users")
|
||||
@console_ns.doc(description="Get workflow online users")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(online_user_list_fields)
|
||||
def get(self):
|
||||
args = WorkflowOnlineUsersQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
app_ids = list(dict.fromkeys(app_id.strip() for app_id in args.app_ids.split(",") if app_id.strip()))
|
||||
if len(app_ids) > MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS:
|
||||
raise BadRequest(f"Maximum {MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS} app_ids are allowed per request.")
|
||||
|
||||
if not app_ids:
|
||||
return {"data": []}
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
workflow_service = WorkflowService()
|
||||
accessible_app_ids = workflow_service.get_accessible_app_ids(app_ids, current_tenant_id)
|
||||
|
||||
results = []
|
||||
for app_id in app_ids:
|
||||
if app_id not in accessible_app_ids:
|
||||
continue
|
||||
|
||||
users_json = redis_client.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{app_id}")
|
||||
|
||||
users = []
|
||||
for _, user_info_json in users_json.items():
|
||||
try:
|
||||
user_info = json.loads(user_info_json)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not isinstance(user_info, dict):
|
||||
continue
|
||||
|
||||
avatar = user_info.get("avatar")
|
||||
if isinstance(avatar, str) and avatar and not avatar.startswith(("http://", "https://")):
|
||||
try:
|
||||
user_info["avatar"] = file_helpers.get_signed_file_url(avatar)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to sign workflow online user avatar; using original value. "
|
||||
"app_id=%s avatar=%s error=%s",
|
||||
app_id,
|
||||
avatar,
|
||||
exc,
|
||||
)
|
||||
|
||||
users.append(user_info)
|
||||
results.append({"app_id": app_id, "users": users})
|
||||
|
||||
return {"data": results}
|
||||
|
||||
@ -1,27 +1,26 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with
|
||||
from flask_restx import Resource
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import (
|
||||
build_workflow_app_log_pagination_model,
|
||||
build_workflow_archived_log_pagination_model,
|
||||
)
|
||||
from fields.base import ResponseModel
|
||||
from fields.end_user_fields import SimpleEndUser
|
||||
from fields.member_fields import SimpleAccount
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowAppLogQuery(BaseModel):
|
||||
keyword: str | None = Field(default=None, description="Search keyword for filtering logs")
|
||||
@ -58,13 +57,113 @@ class WorkflowAppLogQuery(BaseModel):
|
||||
raise ValueError("Invalid boolean value for detail")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
class WorkflowRunForLogResponse(ResponseModel):
|
||||
id: str
|
||||
version: str | None = None
|
||||
status: str | None = None
|
||||
triggered_from: str | None = None
|
||||
error: str | None = None
|
||||
elapsed_time: float | None = None
|
||||
total_tokens: int | None = None
|
||||
total_steps: int | None = None
|
||||
created_at: int | None = None
|
||||
finished_at: int | None = None
|
||||
exceptions_count: int | None = None
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
|
||||
workflow_archived_log_pagination_model = build_workflow_archived_log_pagination_model(console_ns)
|
||||
@field_validator("status", mode="before")
|
||||
@classmethod
|
||||
def _normalize_status(cls, value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(getattr(value, "value", value))
|
||||
|
||||
@field_validator("created_at", "finished_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class WorkflowRunForArchivedLogResponse(ResponseModel):
|
||||
id: str
|
||||
status: str | None = None
|
||||
triggered_from: str | None = None
|
||||
elapsed_time: float | None = None
|
||||
total_tokens: int | None = None
|
||||
|
||||
@field_validator("status", mode="before")
|
||||
@classmethod
|
||||
def _normalize_status(cls, value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(getattr(value, "value", value))
|
||||
|
||||
|
||||
class WorkflowAppLogPartialResponse(ResponseModel):
|
||||
id: str
|
||||
workflow_run: WorkflowRunForLogResponse | None = None
|
||||
details: Any = None
|
||||
created_from: str | None = None
|
||||
created_by_role: str | None = None
|
||||
created_by_account: SimpleAccount | None = None
|
||||
created_by_end_user: SimpleEndUser | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class WorkflowArchivedLogPartialResponse(ResponseModel):
|
||||
id: str
|
||||
workflow_run: WorkflowRunForArchivedLogResponse | None = None
|
||||
trigger_metadata: Any = None
|
||||
created_by_account: SimpleAccount | None = None
|
||||
created_by_end_user: SimpleEndUser | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class WorkflowAppLogPaginationResponse(ResponseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[WorkflowAppLogPartialResponse]
|
||||
|
||||
|
||||
class WorkflowArchivedLogPaginationResponse(ResponseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[WorkflowArchivedLogPartialResponse]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
WorkflowAppLogQuery,
|
||||
WorkflowRunForLogResponse,
|
||||
WorkflowRunForArchivedLogResponse,
|
||||
WorkflowAppLogPartialResponse,
|
||||
WorkflowArchivedLogPartialResponse,
|
||||
WorkflowAppLogPaginationResponse,
|
||||
WorkflowArchivedLogPaginationResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-app-logs")
|
||||
@ -73,12 +172,15 @@ class WorkflowAppLogApi(Resource):
|
||||
@console_ns.doc(description="Get workflow application execution logs")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
|
||||
@console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow app logs retrieved successfully",
|
||||
console_ns.models[WorkflowAppLogPaginationResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_app_log_pagination_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow app logs
|
||||
@ -87,7 +189,7 @@ class WorkflowAppLogApi(Resource):
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
@ -102,7 +204,9 @@ class WorkflowAppLogApi(Resource):
|
||||
created_by_account=args.created_by_account,
|
||||
)
|
||||
|
||||
return workflow_app_log_pagination
|
||||
return WorkflowAppLogPaginationResponse.model_validate(
|
||||
workflow_app_log_pagination, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-archived-logs")
|
||||
@ -111,12 +215,15 @@ class WorkflowArchivedLogApi(Resource):
|
||||
@console_ns.doc(description="Get workflow archived execution logs")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
|
||||
@console_ns.response(200, "Workflow archived logs retrieved successfully", workflow_archived_log_pagination_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow archived logs retrieved successfully",
|
||||
console_ns.models[WorkflowArchivedLogPaginationResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_archived_log_pagination_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow archived logs
|
||||
@ -124,7 +231,7 @@ class WorkflowArchivedLogApi(Resource):
|
||||
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
workflow_app_service = WorkflowAppService()
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
@ -132,4 +239,6 @@ class WorkflowArchivedLogApi(Resource):
|
||||
limit=args.limit,
|
||||
)
|
||||
|
||||
return workflow_app_log_pagination
|
||||
return WorkflowArchivedLogPaginationResponse.model_validate(
|
||||
workflow_app_log_pagination, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
335
api/controllers/console/app/workflow_comment.py
Normal file
335
api/controllers/console/app/workflow_comment.py
Normal file
@ -0,0 +1,335 @@
|
||||
import logging
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.member_fields import AccountWithRole
|
||||
from fields.workflow_comment_fields import (
|
||||
workflow_comment_basic_fields,
|
||||
workflow_comment_create_fields,
|
||||
workflow_comment_detail_fields,
|
||||
workflow_comment_reply_create_fields,
|
||||
workflow_comment_reply_update_fields,
|
||||
workflow_comment_resolve_fields,
|
||||
workflow_comment_update_fields,
|
||||
)
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.workflow_comment_service import WorkflowCommentService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowCommentCreatePayload(BaseModel):
|
||||
content: str = Field(..., description="Comment content")
|
||||
position_x: float = Field(..., description="Comment X position")
|
||||
position_y: float = Field(..., description="Comment Y position")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentUpdatePayload(BaseModel):
|
||||
content: str = Field(..., description="Comment content")
|
||||
position_x: float | None = Field(default=None, description="Comment X position")
|
||||
position_y: float | None = Field(default=None, description="Comment Y position")
|
||||
mentioned_user_ids: list[str] | None = Field(
|
||||
default=None,
|
||||
description="Mentioned user IDs. Omit to keep existing mentions.",
|
||||
)
|
||||
|
||||
|
||||
class WorkflowCommentReplyPayload(BaseModel):
|
||||
content: str = Field(..., description="Reply content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentMentionUsersPayload(BaseModel):
|
||||
users: list[AccountWithRole]
|
||||
|
||||
|
||||
for model in (
|
||||
WorkflowCommentCreatePayload,
|
||||
WorkflowCommentUpdatePayload,
|
||||
WorkflowCommentReplyPayload,
|
||||
):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
register_schema_models(console_ns, AccountWithRole, WorkflowCommentMentionUsersPayload)
|
||||
|
||||
workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields)
|
||||
workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields)
|
||||
workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields)
|
||||
workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields)
|
||||
workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields)
|
||||
workflow_comment_reply_create_model = console_ns.model(
|
||||
"WorkflowCommentReplyCreate", workflow_comment_reply_create_fields
|
||||
)
|
||||
workflow_comment_reply_update_model = console_ns.model(
|
||||
"WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments")
|
||||
class WorkflowCommentListApi(Resource):
|
||||
"""API for listing and creating workflow comments."""
|
||||
|
||||
@console_ns.doc("list_workflow_comments")
|
||||
@console_ns.doc(description="Get all comments for a workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_basic_model, envelope="data")
|
||||
def get(self, app_model: App):
|
||||
"""Get all comments for a workflow."""
|
||||
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
|
||||
|
||||
return comments
|
||||
|
||||
@console_ns.doc("create_workflow_comment")
|
||||
@console_ns.doc(description="Create a new workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__])
|
||||
@console_ns.response(201, "Comment created successfully", workflow_comment_create_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_create_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""Create a new workflow comment."""
|
||||
payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
content=payload.content,
|
||||
position_x=payload.position_x,
|
||||
position_y=payload.position_y,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
|
||||
class WorkflowCommentDetailApi(Resource):
|
||||
"""API for managing individual workflow comments."""
|
||||
|
||||
@console_ns.doc("get_workflow_comment")
|
||||
@console_ns.doc(description="Get a specific workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_detail_model)
|
||||
def get(self, app_model: App, comment_id: str):
|
||||
"""Get a specific workflow comment."""
|
||||
comment = WorkflowCommentService.get_comment(
|
||||
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
@console_ns.doc("update_workflow_comment")
|
||||
@console_ns.doc(description="Update a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Comment updated successfully", workflow_comment_update_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_update_model)
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App, comment_id: str):
|
||||
"""Update a workflow comment."""
|
||||
payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.update_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
content=payload.content,
|
||||
position_x=payload.position_x,
|
||||
position_y=payload.position_y,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@console_ns.doc("delete_workflow_comment")
|
||||
@console_ns.doc(description="Delete a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(204, "Comment deleted successfully")
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, comment_id: str):
|
||||
"""Delete a workflow comment."""
|
||||
WorkflowCommentService.delete_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
|
||||
class WorkflowCommentResolveApi(Resource):
|
||||
"""API for resolving and reopening workflow comments."""
|
||||
|
||||
@console_ns.doc("resolve_workflow_comment")
|
||||
@console_ns.doc(description="Resolve a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_resolve_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Resolve a workflow comment."""
|
||||
comment = WorkflowCommentService.resolve_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
|
||||
class WorkflowCommentReplyApi(Resource):
|
||||
"""API for managing comment replies."""
|
||||
|
||||
@console_ns.doc("create_workflow_comment_reply")
|
||||
@console_ns.doc(description="Add a reply to a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyPayload.__name__])
|
||||
@console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_create_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Add a reply to a workflow comment."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_reply(
|
||||
comment_id=comment_id,
|
||||
content=payload.content,
|
||||
created_by=current_user.id,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>")
|
||||
class WorkflowCommentReplyDetailApi(Resource):
|
||||
"""API for managing individual comment replies."""
|
||||
|
||||
@console_ns.doc("update_workflow_comment_reply")
|
||||
@console_ns.doc(description="Update a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyPayload.__name__])
|
||||
@console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_update_model)
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Update a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
reply = WorkflowCommentService.update_reply(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
reply_id=reply_id,
|
||||
user_id=current_user.id,
|
||||
content=payload.content,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return reply
|
||||
|
||||
@console_ns.doc("delete_workflow_comment_reply")
|
||||
@console_ns.doc(description="Delete a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.response(204, "Reply deleted successfully")
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Delete a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
WorkflowCommentService.delete_reply(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
reply_id=reply_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")
|
||||
class WorkflowCommentMentionUsersApi(Resource):
|
||||
"""API for getting mentionable users for workflow comments."""
|
||||
|
||||
@console_ns.doc("workflow_comment_mention_users")
|
||||
@console_ns.doc(description="Get all users in current tenant for mentions")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(
|
||||
200, "Mentionable users retrieved successfully", console_ns.models[WorkflowCommentMentionUsersPayload.__name__]
|
||||
)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App):
|
||||
"""Get all users in current tenant for mentions."""
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
users = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = WorkflowCommentMentionUsersPayload(users=users)
|
||||
return response.model_dump(mode="json"), 200
|
||||
@ -22,6 +22,7 @@ from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
@ -45,6 +46,16 @@ class WorkflowDraftVariableUpdatePayload(BaseModel):
|
||||
value: Any | None = Field(default=None, description="Variable value")
|
||||
|
||||
|
||||
class ConversationVariableUpdatePayload(BaseModel):
|
||||
conversation_variables: list[dict[str, Any]] = Field(
|
||||
..., description="Conversation variables for the draft workflow"
|
||||
)
|
||||
|
||||
|
||||
class EnvironmentVariableUpdatePayload(BaseModel):
|
||||
environment_variables: list[dict[str, Any]] = Field(..., description="Environment variables for the draft workflow")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowDraftVariableListQuery.__name__,
|
||||
WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
@ -53,6 +64,14 @@ console_ns.schema_model(
|
||||
WorkflowDraftVariableUpdatePayload.__name__,
|
||||
WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ConversationVariableUpdatePayload.__name__,
|
||||
ConversationVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
EnvironmentVariableUpdatePayload.__name__,
|
||||
EnvironmentVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment):
|
||||
@ -510,6 +529,34 @@ class ConversationVariableCollectionApi(Resource):
|
||||
db.session.commit()
|
||||
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
|
||||
|
||||
@console_ns.expect(console_ns.models[ConversationVariableUpdatePayload.__name__])
|
||||
@console_ns.doc("update_conversation_variables")
|
||||
@console_ns.doc(description="Update conversation variables for workflow draft")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Conversation variables updated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
def post(self, app_model: App):
|
||||
payload = ConversationVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
conversation_variables_list = payload.conversation_variables
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
|
||||
workflow_service.update_draft_workflow_conversation_variables(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||
class SystemVariableCollectionApi(Resource):
|
||||
@ -561,3 +608,31 @@ class EnvironmentVariableCollectionApi(Resource):
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
|
||||
@console_ns.expect(console_ns.models[EnvironmentVariableUpdatePayload.__name__])
|
||||
@console_ns.doc("update_environment_variables")
|
||||
@console_ns.doc(description="Update environment variables for workflow draft")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Environment variables updated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
payload = EnvironmentVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
environment_variables_list = payload.environment_variables
|
||||
environment_variables = [
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
|
||||
workflow_service.update_draft_workflow_environment_variables(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -36,7 +36,7 @@ from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowR
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
from services.workflow_run_service import WorkflowRunListArgs, WorkflowRunService
|
||||
|
||||
|
||||
def _build_backstage_input_url(form_token: str | None) -> str | None:
|
||||
@ -214,7 +214,11 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||
Get advanced chat app workflow run list
|
||||
"""
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
args: WorkflowRunListArgs = {"limit": args_model.limit}
|
||||
if args_model.last_id is not None:
|
||||
args["last_id"] = args_model.last_id
|
||||
if args_model.status is not None:
|
||||
args["status"] = args_model.status
|
||||
|
||||
# Default to DEBUGGING if not specified
|
||||
triggered_from = (
|
||||
@ -356,7 +360,11 @@ class WorkflowRunListApi(Resource):
|
||||
Get workflow run list
|
||||
"""
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
args: WorkflowRunListArgs = {"limit": args_model.limit}
|
||||
if args_model.last_id is not None:
|
||||
args["last_id"] = args_model.last_id
|
||||
if args_model.status is not None:
|
||||
args["status"] = args_model.status
|
||||
|
||||
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||
triggered_from = (
|
||||
|
||||
@ -1,16 +1,17 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import get_or_create_model
|
||||
from controllers.common.schema import register_schema_models
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_user, login_required
|
||||
from models.enums import AppTriggerStatus
|
||||
from models.model import Account, App, AppMode
|
||||
@ -21,15 +22,6 @@ from ..app.wraps import get_app_model
|
||||
from ..wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
trigger_model = get_or_create_model("WorkflowTrigger", trigger_fields)
|
||||
|
||||
triggers_list_fields_copy = triggers_list_fields.copy()
|
||||
triggers_list_fields_copy["data"] = fields.List(fields.Nested(trigger_model))
|
||||
triggers_list_model = get_or_create_model("WorkflowTriggerList", triggers_list_fields_copy)
|
||||
|
||||
webhook_trigger_model = get_or_create_model("WebhookTrigger", webhook_trigger_fields)
|
||||
|
||||
|
||||
class Parser(BaseModel):
|
||||
@ -41,10 +33,52 @@ class ParserEnable(BaseModel):
|
||||
enable_trigger: bool
|
||||
|
||||
|
||||
console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
class WorkflowTriggerResponse(ResponseModel):
|
||||
id: str
|
||||
trigger_type: str
|
||||
title: str
|
||||
node_id: str
|
||||
provider_name: str
|
||||
icon: str
|
||||
status: str
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
@field_validator("id", "trigger_type", "title", "node_id", "provider_name", "icon", "status", mode="before")
|
||||
@classmethod
|
||||
def _normalize_string_fields(cls, value: object) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
|
||||
class WorkflowTriggerListResponse(ResponseModel):
|
||||
data: list[WorkflowTriggerResponse]
|
||||
|
||||
|
||||
class WebhookTriggerResponse(ResponseModel):
|
||||
id: str
|
||||
webhook_id: str
|
||||
webhook_url: str
|
||||
webhook_debug_url: str
|
||||
node_id: str
|
||||
created_at: datetime | None = None
|
||||
|
||||
@field_validator("id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", mode="before")
|
||||
@classmethod
|
||||
def _normalize_string_fields(cls, value: object) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
Parser,
|
||||
ParserEnable,
|
||||
WorkflowTriggerResponse,
|
||||
WorkflowTriggerListResponse,
|
||||
WebhookTriggerResponse,
|
||||
)
|
||||
|
||||
|
||||
@ -57,14 +91,14 @@ class WebhookTriggerApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(webhook_trigger_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[WebhookTriggerResponse.__name__])
|
||||
def get(self, app_model: App):
|
||||
"""Get webhook trigger for a node"""
|
||||
args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
node_id = args.node_id
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
# Get webhook trigger for this app and node
|
||||
webhook_trigger = session.scalar(
|
||||
select(WorkflowWebhookTrigger)
|
||||
@ -78,7 +112,7 @@ class WebhookTriggerApi(Resource):
|
||||
if not webhook_trigger:
|
||||
raise NotFound("Webhook trigger not found for this node")
|
||||
|
||||
return webhook_trigger
|
||||
return WebhookTriggerResponse.model_validate(webhook_trigger, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/triggers")
|
||||
@ -89,13 +123,13 @@ class AppTriggersApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(triggers_list_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[WorkflowTriggerListResponse.__name__])
|
||||
def get(self, app_model: App):
|
||||
"""Get app triggers list"""
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
# Get all triggers for this app using select API
|
||||
triggers = (
|
||||
session.execute(
|
||||
@ -118,7 +152,9 @@ class AppTriggersApi(Resource):
|
||||
else:
|
||||
trigger.icon = "" # type: ignore
|
||||
|
||||
return {"data": triggers}
|
||||
return WorkflowTriggerListResponse.model_validate({"data": triggers}, from_attributes=True).model_dump(
|
||||
mode="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
|
||||
@ -129,7 +165,7 @@ class AppTriggerEnableApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(trigger_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[WorkflowTriggerResponse.__name__])
|
||||
def post(self, app_model: App):
|
||||
"""Update app trigger (enable/disable)"""
|
||||
args = ParserEnable.model_validate(console_ns.payload)
|
||||
@ -160,4 +196,4 @@ class AppTriggerEnableApi(Resource):
|
||||
else:
|
||||
trigger.icon = "" # type: ignore
|
||||
|
||||
return trigger
|
||||
return WorkflowTriggerResponse.model_validate(trigger, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@ -1,8 +1,11 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
@ -11,8 +14,6 @@ from libs.helper import EmailStr, timezone
|
||||
from models import AccountStatus
|
||||
from services.account_service import RegisterService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ActivateCheckQuery(BaseModel):
|
||||
workspace_id: str | None = Field(default=None)
|
||||
@ -39,8 +40,16 @@ class ActivatePayload(BaseModel):
|
||||
return timezone(value)
|
||||
|
||||
|
||||
for model in (ActivateCheckQuery, ActivatePayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
class ActivationCheckResponse(BaseModel):
|
||||
is_valid: bool = Field(description="Whether token is valid")
|
||||
data: dict[str, Any] | None = Field(default=None, description="Activation data if valid")
|
||||
|
||||
|
||||
class ActivationResponse(BaseModel):
|
||||
result: str = Field(description="Operation result")
|
||||
|
||||
|
||||
register_schema_models(console_ns, ActivateCheckQuery, ActivatePayload, ActivationCheckResponse, ActivationResponse)
|
||||
|
||||
|
||||
@console_ns.route("/activate/check")
|
||||
@ -51,13 +60,7 @@ class ActivateCheckApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model(
|
||||
"ActivationCheckResponse",
|
||||
{
|
||||
"is_valid": fields.Boolean(description="Whether token is valid"),
|
||||
"data": fields.Raw(description="Activation data if valid"),
|
||||
},
|
||||
),
|
||||
console_ns.models[ActivationCheckResponse.__name__],
|
||||
)
|
||||
def get(self):
|
||||
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
@ -95,12 +98,7 @@ class ActivateApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Account activated successfully",
|
||||
console_ns.model(
|
||||
"ActivationResponse",
|
||||
{
|
||||
"result": fields.String(description="Operation result"),
|
||||
},
|
||||
),
|
||||
console_ns.models[ActivationResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Already activated or invalid token")
|
||||
def post(self):
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
import logging
|
||||
|
||||
import flask_login
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
@ -42,12 +45,13 @@ from libs.token import (
|
||||
)
|
||||
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
from services.entities.auth_entities import LoginPayloadBase
|
||||
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
|
||||
from services.errors.account import AccountRegisterError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoginPayload(LoginPayloadBase):
|
||||
@ -91,10 +95,12 @@ class LoginApi(Resource):
|
||||
normalized_email = request_email.lower()
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
|
||||
raise AccountInFreezeError()
|
||||
|
||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
|
||||
if is_login_error_rate_limit:
|
||||
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.LOGIN_RATE_LIMITED)
|
||||
raise EmailPasswordLoginLimitError()
|
||||
|
||||
invite_token = args.invite_token
|
||||
@ -110,14 +116,20 @@ class LoginApi(Resource):
|
||||
invitee_email = data.get("email") if data else None
|
||||
invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
|
||||
if invitee_email_normalized != normalized_email:
|
||||
_log_console_login_failure(
|
||||
email=normalized_email,
|
||||
reason=LoginFailureReason.INVALID_INVITATION_EMAIL,
|
||||
)
|
||||
raise InvalidEmailError()
|
||||
account = _authenticate_account_with_case_fallback(
|
||||
request_email, normalized_email, args.password, invite_token
|
||||
)
|
||||
except services.errors.account.AccountLoginError:
|
||||
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError as exc:
|
||||
AccountService.add_login_error_rate_limit(normalized_email)
|
||||
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
|
||||
raise AuthenticationFailedError() from exc
|
||||
# SELF_HOSTED only have one workspace
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
@ -240,20 +252,27 @@ class EmailCodeLoginApi(Resource):
|
||||
|
||||
token_data = AccountService.get_email_code_login_data(args.token)
|
||||
if token_data is None:
|
||||
_log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
|
||||
raise InvalidTokenError()
|
||||
|
||||
token_email = token_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
if normalized_token_email != user_email:
|
||||
_log_console_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != args.code:
|
||||
_log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
|
||||
raise EmailCodeError()
|
||||
|
||||
AccountService.revoke_email_code_login_token(args.token)
|
||||
try:
|
||||
account = _get_account_with_case_fallback(original_email)
|
||||
except Unauthorized as exc:
|
||||
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
|
||||
raise AccountBannedError() from exc
|
||||
except AccountRegisterError:
|
||||
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
|
||||
raise AccountInFreezeError()
|
||||
if account:
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
@ -279,6 +298,7 @@ class EmailCodeLoginApi(Resource):
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
raise NotAllowedCreateWorkspace()
|
||||
except AccountRegisterError:
|
||||
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
|
||||
raise AccountInFreezeError()
|
||||
except WorkspacesLimitExceededError:
|
||||
raise WorkspacesLimitExceeded()
|
||||
@ -336,3 +356,12 @@ def _authenticate_account_with_case_fallback(
|
||||
if original_email == normalized_email:
|
||||
raise
|
||||
return AccountService.authenticate(normalized_email, password, invite_token)
|
||||
|
||||
|
||||
def _log_console_login_failure(*, email: str, reason: LoginFailureReason) -> None:
|
||||
logger.warning(
|
||||
"Console login failed: email=%s reason=%s ip_address=%s",
|
||||
email,
|
||||
reason,
|
||||
extract_remote_ip(request),
|
||||
)
|
||||
|
||||
@ -11,10 +11,7 @@ import services
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import get_or_create_model, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.apikey import (
|
||||
api_key_item_model,
|
||||
api_key_list_model,
|
||||
)
|
||||
from controllers.console.apikey import ApiKeyItem, ApiKeyList
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
|
||||
from controllers.console.wraps import (
|
||||
@ -785,23 +782,23 @@ class DatasetApiKeyApi(Resource):
|
||||
|
||||
@console_ns.doc("get_dataset_api_keys")
|
||||
@console_ns.doc(description="Get dataset API keys")
|
||||
@console_ns.response(200, "API keys retrieved successfully", api_key_list_model)
|
||||
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_key_list_model)
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
keys = db.session.scalars(
|
||||
select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
|
||||
).all()
|
||||
return {"items": keys}
|
||||
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@console_ns.response(200, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_key_item_model)
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@ -828,7 +825,7 @@ class DatasetApiKeyApi(Resource):
|
||||
api_token.type = self.resource_type
|
||||
db.session.add(api_token)
|
||||
db.session.commit()
|
||||
return api_token, 200
|
||||
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/api-keys/<uuid:api_key_id>")
|
||||
|
||||
@ -4,7 +4,6 @@ from argparse import ArgumentTypeError
|
||||
from collections.abc import Sequence
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Literal, cast
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request, send_file
|
||||
@ -16,6 +15,7 @@ from sqlalchemy import asc, desc, func, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload
|
||||
from controllers.common.schema import get_or_create_model, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from core.errors.error import (
|
||||
@ -71,9 +71,6 @@ from ..wraps import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# NOTE: Keep constants near the top of the module for discoverability.
|
||||
DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
|
||||
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
dataset_model = get_or_create_model("Dataset", dataset_fields)
|
||||
@ -110,12 +107,6 @@ class GenerateSummaryPayload(BaseModel):
|
||||
document_list: list[str]
|
||||
|
||||
|
||||
class DocumentBatchDownloadZipPayload(BaseModel):
|
||||
"""Request payload for bulk downloading documents as a zip archive."""
|
||||
|
||||
document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
|
||||
|
||||
|
||||
class DocumentDatasetListParam(BaseModel):
|
||||
page: int = Field(1, title="Page", description="Page number.")
|
||||
limit: int = Field(20, title="Limit", description="Page size.")
|
||||
@ -1035,7 +1026,7 @@ class DocumentMetadataApi(DocumentResource):
|
||||
|
||||
if not isinstance(doc_metadata, dict):
|
||||
raise ValueError("doc_metadata must be a dictionary.")
|
||||
metadata_schema: dict = cast(dict, DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
|
||||
metadata_schema: dict[str, Any] = cast(dict[str, Any], DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
|
||||
|
||||
document.doc_metadata = {}
|
||||
if doc_type == "others":
|
||||
|
||||
@ -10,6 +10,7 @@ from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
@ -82,14 +83,6 @@ class BatchImportPayload(BaseModel):
|
||||
upload_file_id: str
|
||||
|
||||
|
||||
class ChildChunkCreatePayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class ChildChunkUpdatePayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class ChildChunkBatchUpdatePayload(BaseModel):
|
||||
chunks: list[ChildChunkUpdateArgs]
|
||||
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
from flask_restx import Resource, fields
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from fields.hit_testing_fields import (
|
||||
child_chunk_fields,
|
||||
document_fields,
|
||||
files_fields,
|
||||
hit_testing_record_fields,
|
||||
segment_fields,
|
||||
)
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import login_required
|
||||
|
||||
from .. import console_ns
|
||||
@ -18,39 +18,92 @@ from ..wraps import (
|
||||
setup_required,
|
||||
)
|
||||
|
||||
register_schema_model(console_ns, HitTestingPayload)
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
def _get_or_create_model(model_name: str, field_def):
|
||||
"""Get or create a flask_restx model to avoid dict type issues in Swagger."""
|
||||
existing = console_ns.models.get(model_name)
|
||||
if existing is None:
|
||||
existing = console_ns.model(model_name, field_def)
|
||||
return existing
|
||||
class HitTestingDocument(ResponseModel):
|
||||
id: str | None = None
|
||||
data_source_type: str | None = None
|
||||
name: str | None = None
|
||||
doc_type: str | None = None
|
||||
doc_metadata: Any | None = None
|
||||
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
document_model = _get_or_create_model("HitTestingDocument", document_fields)
|
||||
class HitTestingSegment(ResponseModel):
|
||||
id: str | None = None
|
||||
position: int | None = None
|
||||
document_id: str | None = None
|
||||
content: str | None = None
|
||||
sign_content: str | None = None
|
||||
answer: str | None = None
|
||||
word_count: int | None = None
|
||||
tokens: int | None = None
|
||||
keywords: list[str] = Field(default_factory=list)
|
||||
index_node_id: str | None = None
|
||||
index_node_hash: str | None = None
|
||||
hit_count: int | None = None
|
||||
enabled: bool | None = None
|
||||
disabled_at: int | None = None
|
||||
disabled_by: str | None = None
|
||||
status: str | None = None
|
||||
created_by: str | None = None
|
||||
created_at: int | None = None
|
||||
indexing_at: int | None = None
|
||||
completed_at: int | None = None
|
||||
error: str | None = None
|
||||
stopped_at: int | None = None
|
||||
document: HitTestingDocument | None = None
|
||||
|
||||
segment_fields_copy = segment_fields.copy()
|
||||
segment_fields_copy["document"] = fields.Nested(document_model)
|
||||
segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
|
||||
@field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
|
||||
files_model = _get_or_create_model("HitTestingFile", files_fields)
|
||||
|
||||
hit_testing_record_fields_copy = hit_testing_record_fields.copy()
|
||||
hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
|
||||
hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
|
||||
hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
|
||||
hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
|
||||
class HitTestingChildChunk(ResponseModel):
|
||||
id: str | None = None
|
||||
content: str | None = None
|
||||
position: int | None = None
|
||||
score: float | None = None
|
||||
|
||||
# Response model for hit testing API
|
||||
hit_testing_response_fields = {
|
||||
"query": fields.String,
|
||||
"records": fields.List(fields.Nested(hit_testing_record_model)),
|
||||
}
|
||||
hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
|
||||
|
||||
class HitTestingFile(ResponseModel):
|
||||
id: str | None = None
|
||||
name: str | None = None
|
||||
size: int | None = None
|
||||
extension: str | None = None
|
||||
mime_type: str | None = None
|
||||
source_url: str | None = None
|
||||
|
||||
|
||||
class HitTestingRecord(ResponseModel):
|
||||
segment: HitTestingSegment | None = None
|
||||
child_chunks: list[HitTestingChildChunk] = Field(default_factory=list)
|
||||
score: float | None = None
|
||||
tsne_position: Any | None = None
|
||||
files: list[HitTestingFile] = Field(default_factory=list)
|
||||
summary: str | None = None
|
||||
|
||||
|
||||
class HitTestingResponse(ResponseModel):
|
||||
query: str
|
||||
records: list[HitTestingRecord] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
HitTestingPayload,
|
||||
HitTestingDocument,
|
||||
HitTestingSegment,
|
||||
HitTestingChildChunk,
|
||||
HitTestingFile,
|
||||
HitTestingRecord,
|
||||
HitTestingResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
|
||||
@ -59,7 +112,11 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
@console_ns.doc(description="Test dataset knowledge retrieval")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
|
||||
@console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Hit testing completed successfully",
|
||||
model=console_ns.models[HitTestingResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@setup_required
|
||||
@ -74,4 +131,4 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
self.hit_testing_args_check(args)
|
||||
|
||||
return self.perform_hit_testing(dataset, args)
|
||||
return HitTestingResponse.model_validate(self.perform_hit_testing(dataset, args)).model_dump(mode="json")
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MetadataUpdatePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
@ -18,11 +18,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
)
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
class MetadataUpdatePayload(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail
|
||||
)
|
||||
|
||||
@ -1,21 +1,24 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from graphon.file import helpers as file_helpers
|
||||
from pydantic import BaseModel, Field, computed_field, field_validator
|
||||
from sqlalchemy import and_, select
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
from controllers.common.schema import get_or_create_model
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from extensions.ext_database import db
|
||||
from fields.installed_app_fields import app_fields, installed_app_fields, installed_app_list_fields
|
||||
from fields.base import ResponseModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, InstalledApp, RecommendedApp
|
||||
from models.model import IconType
|
||||
from services.account_service import TenantService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
@ -36,22 +39,97 @@ class InstalledAppsListQuery(BaseModel):
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
app_model = get_or_create_model("InstalledAppInfo", app_fields)
|
||||
def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
|
||||
if icon is None or icon_type is None:
|
||||
return None
|
||||
icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
|
||||
if icon_type_value.lower() != IconType.IMAGE:
|
||||
return None
|
||||
return file_helpers.get_signed_file_url(icon)
|
||||
|
||||
installed_app_fields_copy = installed_app_fields.copy()
|
||||
installed_app_fields_copy["app"] = fields.Nested(app_model)
|
||||
installed_app_model = get_or_create_model("InstalledApp", installed_app_fields_copy)
|
||||
|
||||
installed_app_list_fields_copy = installed_app_list_fields.copy()
|
||||
installed_app_list_fields_copy["installed_apps"] = fields.List(fields.Nested(installed_app_model))
|
||||
installed_app_list_model = get_or_create_model("InstalledAppList", installed_app_list_fields_copy)
|
||||
def _safe_primitive(value: Any) -> Any:
|
||||
if value is None or isinstance(value, (str, int, float, bool, datetime)):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
class InstalledAppInfoResponse(ResponseModel):
|
||||
id: str
|
||||
name: str | None = None
|
||||
mode: str | None = None
|
||||
icon_type: str | None = None
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
use_icon_as_answer_icon: bool | None = None
|
||||
|
||||
@field_validator("mode", "icon_type", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum_like(cls, value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(getattr(value, "value", value))
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore[prop-decorator]
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return _build_icon_url(self.icon_type, self.icon)
|
||||
|
||||
|
||||
class InstalledAppResponse(ResponseModel):
|
||||
id: str
|
||||
app: InstalledAppInfoResponse
|
||||
app_owner_tenant_id: str
|
||||
is_pinned: bool
|
||||
last_used_at: int | None = None
|
||||
editable: bool
|
||||
uninstallable: bool
|
||||
|
||||
@field_validator("app", mode="before")
|
||||
@classmethod
|
||||
def _normalize_app(cls, value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
return {
|
||||
"id": _safe_primitive(getattr(value, "id", "")) or "",
|
||||
"name": _safe_primitive(getattr(value, "name", None)),
|
||||
"mode": _safe_primitive(getattr(value, "mode", None)),
|
||||
"icon_type": _safe_primitive(getattr(value, "icon_type", None)),
|
||||
"icon": _safe_primitive(getattr(value, "icon", None)),
|
||||
"icon_background": _safe_primitive(getattr(value, "icon_background", None)),
|
||||
"use_icon_as_answer_icon": _safe_primitive(getattr(value, "use_icon_as_answer_icon", None)),
|
||||
}
|
||||
|
||||
@field_validator("last_used_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class InstalledAppListResponse(ResponseModel):
|
||||
installed_apps: list[InstalledAppResponse]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
InstalledAppCreatePayload,
|
||||
InstalledAppUpdatePayload,
|
||||
InstalledAppsListQuery,
|
||||
InstalledAppInfoResponse,
|
||||
InstalledAppResponse,
|
||||
InstalledAppListResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps")
|
||||
class InstalledAppsListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(installed_app_list_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__])
|
||||
def get(self):
|
||||
query = InstalledAppsListQuery.model_validate(request.args.to_dict())
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@ -125,7 +203,9 @@ class InstalledAppsListApi(Resource):
|
||||
)
|
||||
)
|
||||
|
||||
return {"installed_apps": installed_app_list}
|
||||
return InstalledAppListResponse.model_validate(
|
||||
{"installed_apps": installed_app_list}, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@ -1,66 +1,83 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, computed_field, field_validator
|
||||
|
||||
from constants.languages import languages
|
||||
from controllers.common.schema import get_or_create_model
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from libs.helper import AppIconUrlField
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import build_icon_url
|
||||
from libs.login import current_user, login_required
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
app_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"mode": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_type": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"icon_background": fields.String,
|
||||
}
|
||||
|
||||
app_model = get_or_create_model("RecommendedAppInfo", app_fields)
|
||||
|
||||
recommended_app_fields = {
|
||||
"app": fields.Nested(app_model, attribute="app"),
|
||||
"app_id": fields.String,
|
||||
"description": fields.String(attribute="description"),
|
||||
"copyright": fields.String,
|
||||
"privacy_policy": fields.String,
|
||||
"custom_disclaimer": fields.String,
|
||||
"category": fields.String,
|
||||
"position": fields.Integer,
|
||||
"is_listed": fields.Boolean,
|
||||
"can_trial": fields.Boolean,
|
||||
}
|
||||
|
||||
recommended_app_model = get_or_create_model("RecommendedApp", recommended_app_fields)
|
||||
|
||||
recommended_app_list_fields = {
|
||||
"recommended_apps": fields.List(fields.Nested(recommended_app_model)),
|
||||
"categories": fields.List(fields.String),
|
||||
}
|
||||
|
||||
recommended_app_list_model = get_or_create_model("RecommendedAppList", recommended_app_list_fields)
|
||||
|
||||
|
||||
class RecommendedAppsQuery(BaseModel):
|
||||
language: str | None = Field(default=None)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
RecommendedAppsQuery.__name__,
|
||||
RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
class RecommendedAppInfoResponse(ResponseModel):
|
||||
id: str
|
||||
name: str | None = None
|
||||
mode: str | None = None
|
||||
icon: str | None = None
|
||||
icon_type: str | None = None
|
||||
icon_background: str | None = None
|
||||
|
||||
@staticmethod
|
||||
def _normalize_enum_like(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(getattr(value, "value", value))
|
||||
|
||||
@field_validator("mode", "icon_type", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum_fields(cls, value: Any) -> str | None:
|
||||
return cls._normalize_enum_like(value)
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore[prop-decorator]
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return build_icon_url(self.icon_type, self.icon)
|
||||
|
||||
|
||||
class RecommendedAppResponse(ResponseModel):
|
||||
app: RecommendedAppInfoResponse | None = None
|
||||
app_id: str
|
||||
description: str | None = None
|
||||
copyright: str | None = None
|
||||
privacy_policy: str | None = None
|
||||
custom_disclaimer: str | None = None
|
||||
category: str | None = None
|
||||
position: int | None = None
|
||||
is_listed: bool | None = None
|
||||
can_trial: bool | None = None
|
||||
|
||||
|
||||
class RecommendedAppListResponse(ResponseModel):
|
||||
recommended_apps: list[RecommendedAppResponse]
|
||||
categories: list[str]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
RecommendedAppsQuery,
|
||||
RecommendedAppInfoResponse,
|
||||
RecommendedAppResponse,
|
||||
RecommendedAppListResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/explore/apps")
|
||||
class RecommendedAppListApi(Resource):
|
||||
@console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[RecommendedAppListResponse.__name__])
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(recommended_app_list_model)
|
||||
def get(self):
|
||||
# language args
|
||||
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
@ -72,7 +89,10 @@ class RecommendedAppListApi(Resource):
|
||||
else:
|
||||
language_prefix = languages[0]
|
||||
|
||||
return RecommendedAppService.get_recommended_apps_and_categories(language_prefix)
|
||||
return RecommendedAppListResponse.model_validate(
|
||||
RecommendedAppService.get_recommended_apps_and_categories(language_prefix),
|
||||
from_attributes=True,
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/explore/apps/<uuid:app_id>")
|
||||
|
||||
@ -169,6 +169,7 @@ console_ns.schema_model(
|
||||
|
||||
|
||||
class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
@console_ns.expect(console_ns.models[WorkflowRunRequest.__name__])
|
||||
def post(self, trial_app):
|
||||
"""
|
||||
@ -210,6 +211,7 @@ class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialAppWorkflowTaskStopApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app, task_id: str):
|
||||
"""
|
||||
Stop workflow task
|
||||
@ -290,7 +292,6 @@ class TrialChatApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialMessageSuggestedQuestionApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def get(self, trial_app, message_id):
|
||||
app_model = trial_app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
@ -470,7 +471,6 @@ class TrialCompletionApi(TrialAppResource):
|
||||
class TrialSitApi(Resource):
|
||||
"""Resource for trial app sites."""
|
||||
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial(None)
|
||||
def get(self, app_model):
|
||||
"""Retrieve app site info.
|
||||
@ -492,7 +492,6 @@ class TrialSitApi(Resource):
|
||||
class TrialAppParameterApi(Resource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial(None)
|
||||
def get(self, app_model):
|
||||
"""Retrieve app parameters."""
|
||||
@ -521,7 +520,6 @@ class TrialAppParameterApi(Resource):
|
||||
|
||||
|
||||
class AppApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial(None)
|
||||
@marshal_with(app_detail_with_site_model)
|
||||
def get(self, app_model):
|
||||
@ -534,7 +532,6 @@ class AppApi(Resource):
|
||||
|
||||
|
||||
class AppWorkflowApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial(None)
|
||||
@marshal_with(workflow_model)
|
||||
def get(self, app_model):
|
||||
@ -547,7 +544,6 @@ class AppWorkflowApi(Resource):
|
||||
|
||||
|
||||
class DatasetListApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial(None)
|
||||
def get(self, app_model):
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
|
||||
@ -1,15 +1,18 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from fields.api_based_extension_fields import api_based_extension_fields
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
from services.code_based_extension_service import CodeBasedExtensionService
|
||||
|
||||
from ..common.schema import register_schema_models
|
||||
from ..common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_models
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, setup_required
|
||||
|
||||
@ -24,12 +27,52 @@ class APIBasedExtensionPayload(BaseModel):
|
||||
api_key: str = Field(description="API key for authentication")
|
||||
|
||||
|
||||
register_schema_models(console_ns, APIBasedExtensionPayload)
|
||||
class CodeBasedExtensionResponse(ResponseModel):
|
||||
module: str = Field(description="Module name")
|
||||
data: Any = Field(description="Extension data")
|
||||
|
||||
|
||||
api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields)
|
||||
def _mask_api_key(api_key: str) -> str:
|
||||
if not api_key:
|
||||
return api_key
|
||||
if len(api_key) <= 8:
|
||||
return api_key[0] + "******" + api_key[-1]
|
||||
return api_key[:3] + "******" + api_key[-3:]
|
||||
|
||||
api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model))
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class APIBasedExtensionResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
api_endpoint: str
|
||||
api_key: str
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("api_key", mode="before")
|
||||
@classmethod
|
||||
def _normalize_api_key(cls, value: str) -> str:
|
||||
return _mask_api_key(value)
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
register_schema_models(console_ns, APIBasedExtensionPayload, CodeBasedExtensionResponse, APIBasedExtensionResponse)
|
||||
console_ns.schema_model(
|
||||
"APIBasedExtensionListResponse",
|
||||
TypeAdapter(list[APIBasedExtensionResponse]).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
def _serialize_api_based_extension(extension: APIBasedExtension) -> dict[str, Any]:
|
||||
return APIBasedExtensionResponse.model_validate(extension, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/code-based-extension")
|
||||
@ -40,10 +83,7 @@ class CodeBasedExtensionAPI(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model(
|
||||
"CodeBasedExtensionResponse",
|
||||
{"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")},
|
||||
),
|
||||
console_ns.models[CodeBasedExtensionResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -51,30 +91,34 @@ class CodeBasedExtensionAPI(Resource):
|
||||
def get(self):
|
||||
query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
return {"module": query.module, "data": CodeBasedExtensionService.get_code_based_extension(query.module)}
|
||||
return CodeBasedExtensionResponse(
|
||||
module=query.module,
|
||||
data=CodeBasedExtensionService.get_code_based_extension(query.module),
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/api-based-extension")
|
||||
class APIBasedExtensionAPI(Resource):
|
||||
@console_ns.doc("get_api_based_extensions")
|
||||
@console_ns.doc(description="Get all API-based extensions for current tenant")
|
||||
@console_ns.response(200, "Success", api_based_extension_list_model)
|
||||
@console_ns.response(200, "Success", console_ns.models["APIBasedExtensionListResponse"])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_model)
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
||||
return [
|
||||
_serialize_api_based_extension(extension)
|
||||
for extension in APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
||||
]
|
||||
|
||||
@console_ns.doc("create_api_based_extension")
|
||||
@console_ns.doc(description="Create a new API-based extension")
|
||||
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
|
||||
@console_ns.response(201, "Extension created successfully", api_based_extension_model)
|
||||
@console_ns.response(201, "Extension created successfully", console_ns.models[APIBasedExtensionResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_model)
|
||||
def post(self):
|
||||
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@ -86,7 +130,7 @@ class APIBasedExtensionAPI(Resource):
|
||||
api_key=payload.api_key,
|
||||
)
|
||||
|
||||
return APIBasedExtensionService.save(extension_data)
|
||||
return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data))
|
||||
|
||||
|
||||
@console_ns.route("/api-based-extension/<uuid:id>")
|
||||
@ -94,26 +138,26 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
@console_ns.doc("get_api_based_extension")
|
||||
@console_ns.doc(description="Get API-based extension by ID")
|
||||
@console_ns.doc(params={"id": "Extension ID"})
|
||||
@console_ns.response(200, "Success", api_based_extension_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[APIBasedExtensionResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_model)
|
||||
def get(self, id):
|
||||
api_based_extension_id = str(id)
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||
return _serialize_api_based_extension(
|
||||
APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||
)
|
||||
|
||||
@console_ns.doc("update_api_based_extension")
|
||||
@console_ns.doc(description="Update API-based extension")
|
||||
@console_ns.doc(params={"id": "Extension ID"})
|
||||
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
|
||||
@console_ns.response(200, "Extension updated successfully", api_based_extension_model)
|
||||
@console_ns.response(200, "Extension updated successfully", console_ns.models[APIBasedExtensionResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_model)
|
||||
def post(self, id):
|
||||
api_based_extension_id = str(id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@ -128,7 +172,7 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
if payload.api_key != HIDDEN_VALUE:
|
||||
extension_data_from_db.api_key = payload.api_key
|
||||
|
||||
return APIBasedExtensionService.save(extension_data_from_db)
|
||||
return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data_from_db))
|
||||
|
||||
@console_ns.doc("delete_api_based_extension")
|
||||
@console_ns.doc(description="Delete API-based extension")
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import TypedDict
|
||||
|
||||
from flask import request
|
||||
@ -13,6 +14,14 @@ from services.billing_service import BillingService
|
||||
_FALLBACK_LANG = "en-US"
|
||||
|
||||
|
||||
class NotificationLangContent(TypedDict, total=False):
|
||||
lang: str
|
||||
title: str
|
||||
subtitle: str
|
||||
body: str
|
||||
titlePicUrl: str
|
||||
|
||||
|
||||
class NotificationItemDict(TypedDict):
|
||||
notification_id: str | None
|
||||
frequency: str | None
|
||||
@ -28,9 +37,11 @@ class NotificationResponseDict(TypedDict):
|
||||
notifications: list[NotificationItemDict]
|
||||
|
||||
|
||||
def _pick_lang_content(contents: dict, lang: str) -> dict:
|
||||
def _pick_lang_content(contents: Mapping[str, NotificationLangContent], lang: str) -> NotificationLangContent:
|
||||
"""Return the single LangContent for *lang*, falling back to English."""
|
||||
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
|
||||
return (
|
||||
contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), NotificationLangContent())
|
||||
)
|
||||
|
||||
|
||||
class DismissNotificationPayload(BaseModel):
|
||||
@ -71,7 +82,7 @@ class NotificationApi(Resource):
|
||||
|
||||
notifications: list[NotificationItemDict] = []
|
||||
for notification in result.get("notifications") or []:
|
||||
contents: dict = notification.get("contents") or {}
|
||||
contents: Mapping[str, NotificationLangContent] = notification.get("contents") or {}
|
||||
lang_content = _pick_lang_content(contents, lang)
|
||||
item: NotificationItemDict = {
|
||||
"notification_id": notification.get("notificationId"),
|
||||
|
||||
1
api/controllers/console/socketio/__init__.py
Normal file
1
api/controllers/console/socketio/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
|
||||
108
api/controllers/console/socketio/workflow.py
Normal file
108
api/controllers/console/socketio/workflow.py
Normal file
@ -0,0 +1,108 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from flask import Request as FlaskRequest
|
||||
|
||||
from extensions.ext_socketio import sio
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_access_token
|
||||
from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository
|
||||
from services.account_service import AccountService
|
||||
from services.workflow_collaboration_service import WorkflowCollaborationService
|
||||
|
||||
repository = WorkflowCollaborationRepository()
|
||||
collaboration_service = WorkflowCollaborationService(repository, sio)
|
||||
|
||||
|
||||
def _sio_on(event: str) -> Callable[[Callable[..., object]], Callable[..., object]]:
|
||||
return cast(Callable[[Callable[..., object]], Callable[..., object]], sio.on(event))
|
||||
|
||||
|
||||
@_sio_on("connect")
|
||||
def socket_connect(sid, environ, auth):
|
||||
"""
|
||||
WebSocket connect event, do authentication here.
|
||||
"""
|
||||
try:
|
||||
request_environ = FlaskRequest(environ)
|
||||
token = extract_access_token(request_environ)
|
||||
except Exception:
|
||||
logging.exception("Failed to extract token")
|
||||
token = None
|
||||
|
||||
if not token:
|
||||
logging.warning("Socket connect rejected: missing token (sid=%s)", sid)
|
||||
return False
|
||||
|
||||
try:
|
||||
decoded = PassportService().verify(token)
|
||||
user_id = decoded.get("user_id")
|
||||
if not user_id:
|
||||
logging.warning("Socket connect rejected: missing user_id (sid=%s)", sid)
|
||||
return False
|
||||
|
||||
with sio.app.app_context():
|
||||
user = AccountService.load_logged_in_account(account_id=user_id)
|
||||
if not user:
|
||||
logging.warning("Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid)
|
||||
return False
|
||||
if not user.has_edit_permission:
|
||||
logging.warning("Socket connect rejected: no edit permission (user_id=%s, sid=%s)", user_id, sid)
|
||||
return False
|
||||
|
||||
collaboration_service.save_socket_identity(sid, user)
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
logging.exception("Socket authentication failed")
|
||||
return False
|
||||
|
||||
|
||||
@_sio_on("user_connect")
|
||||
def handle_user_connect(sid, data):
|
||||
"""
|
||||
Handle user connect event. Each session (tab) is treated as an independent collaborator.
|
||||
"""
|
||||
workflow_id = data.get("workflow_id")
|
||||
if not workflow_id:
|
||||
return {"msg": "workflow_id is required"}, 400
|
||||
|
||||
result = collaboration_service.authorize_and_join_workflow_room(workflow_id, sid)
|
||||
if not result:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
user_id, is_leader = result
|
||||
return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader}
|
||||
|
||||
|
||||
@_sio_on("disconnect")
|
||||
def handle_disconnect(sid):
|
||||
"""
|
||||
Handle session disconnect event. Remove the specific session from online users.
|
||||
"""
|
||||
collaboration_service.disconnect_session(sid)
|
||||
|
||||
|
||||
@_sio_on("collaboration_event")
|
||||
def handle_collaboration_event(sid, data):
|
||||
"""
|
||||
Handle general collaboration events, include:
|
||||
1. mouse_move
|
||||
2. vars_and_features_update
|
||||
3. sync_request (ask leader to update graph)
|
||||
4. app_state_update
|
||||
5. mcp_server_update
|
||||
6. workflow_update
|
||||
7. comments_update
|
||||
8. node_panel_presence
|
||||
"""
|
||||
return collaboration_service.relay_collaboration_event(sid, data)
|
||||
|
||||
|
||||
@_sio_on("graph_event")
|
||||
def handle_graph_event(sid, data):
|
||||
"""
|
||||
Handle graph events - simple broadcast relay.
|
||||
"""
|
||||
return collaboration_service.relay_graph_event(sid, data)
|
||||
@ -1,13 +1,14 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Namespace, Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import TagType
|
||||
from services.tag_service import (
|
||||
@ -18,17 +19,6 @@ from services.tag_service import (
|
||||
UpdateTagPayload,
|
||||
)
|
||||
|
||||
dataset_tag_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"type": fields.String,
|
||||
"binding_count": fields.String,
|
||||
}
|
||||
|
||||
|
||||
def build_dataset_tag_fields(api_or_ns: Namespace):
|
||||
return api_or_ns.model("DataSetTag", dataset_tag_fields)
|
||||
|
||||
|
||||
class TagBasePayload(BaseModel):
|
||||
name: str = Field(description="Tag name", min_length=1, max_length=50)
|
||||
@ -52,12 +42,36 @@ class TagListQueryParam(BaseModel):
|
||||
keyword: str | None = Field(None, description="Search keyword")
|
||||
|
||||
|
||||
class TagResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: str | None = None
|
||||
binding_count: str | None = None
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
def normalize_type(cls, value: TagType | str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, TagType):
|
||||
return value.value
|
||||
return value
|
||||
|
||||
@field_validator("binding_count", mode="before")
|
||||
@classmethod
|
||||
def normalize_binding_count(cls, value: int | str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
TagBasePayload,
|
||||
TagBindingPayload,
|
||||
TagBindingRemovePayload,
|
||||
TagListQueryParam,
|
||||
TagResponse,
|
||||
)
|
||||
|
||||
|
||||
@ -69,14 +83,18 @@ class TagListApi(Resource):
|
||||
@console_ns.doc(
|
||||
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
|
||||
)
|
||||
@marshal_with(dataset_tag_fields)
|
||||
@console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])})
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
raw_args = request.args.to_dict()
|
||||
param = TagListQueryParam.model_validate(raw_args)
|
||||
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
|
||||
|
||||
return tags, 200
|
||||
serialized_tags = [
|
||||
TagResponse.model_validate(tag, from_attributes=True).model_dump(mode="json") for tag in tags
|
||||
]
|
||||
|
||||
return serialized_tags, 200
|
||||
|
||||
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
|
||||
@setup_required
|
||||
@ -91,7 +109,9 @@ class TagListApi(Resource):
|
||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type))
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
response = TagResponse.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
).model_dump(mode="json")
|
||||
|
||||
return response, 200
|
||||
|
||||
@ -114,7 +134,9 @@ class TagUpdateDeleteApi(Resource):
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
response = TagResponse.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
).model_dump(mode="json")
|
||||
|
||||
return response, 200
|
||||
|
||||
|
||||
@ -35,22 +35,24 @@ def plugin_permission_required(
|
||||
return view(*args, **kwargs)
|
||||
|
||||
if install_required:
|
||||
if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY:
|
||||
raise Forbidden()
|
||||
if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS:
|
||||
if not user.is_admin_or_owner:
|
||||
match permission.install_permission:
|
||||
case TenantPluginPermission.InstallPermission.NOBODY:
|
||||
raise Forbidden()
|
||||
if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE:
|
||||
pass
|
||||
case TenantPluginPermission.InstallPermission.ADMINS:
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
case TenantPluginPermission.InstallPermission.EVERYONE:
|
||||
pass
|
||||
|
||||
if debug_required:
|
||||
if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY:
|
||||
raise Forbidden()
|
||||
if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS:
|
||||
if not user.is_admin_or_owner:
|
||||
match permission.debug_permission:
|
||||
case TenantPluginPermission.DebugPermission.NOBODY:
|
||||
raise Forbidden()
|
||||
if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE:
|
||||
pass
|
||||
case TenantPluginPermission.DebugPermission.ADMINS:
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
case TenantPluginPermission.DebugPermission.EVERYONE:
|
||||
pass
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
import pytz
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx import Resource
|
||||
from graphon.file import helpers as file_helpers
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from sqlalchemy import select
|
||||
|
||||
@ -37,9 +38,10 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from fields.member_fields import Account as AccountResponse
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
|
||||
from libs.helper import EmailStr, extract_remote_ip, timezone
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AccountIntegrate, InvitationCode
|
||||
from models.account import AccountStatus, InvitationCodeStatus
|
||||
@ -74,6 +76,10 @@ class AccountAvatarPayload(BaseModel):
|
||||
avatar: str
|
||||
|
||||
|
||||
class AccountAvatarQuery(BaseModel):
|
||||
avatar: str = Field(..., description="Avatar file ID")
|
||||
|
||||
|
||||
class AccountInterfaceLanguagePayload(BaseModel):
|
||||
interface_language: str
|
||||
|
||||
@ -159,6 +165,7 @@ def reg(cls: type[BaseModel]):
|
||||
reg(AccountInitPayload)
|
||||
reg(AccountNamePayload)
|
||||
reg(AccountAvatarPayload)
|
||||
reg(AccountAvatarQuery)
|
||||
reg(AccountInterfaceLanguagePayload)
|
||||
reg(AccountInterfaceThemePayload)
|
||||
reg(AccountTimezonePayload)
|
||||
@ -174,21 +181,61 @@ reg(CheckEmailUniquePayload)
|
||||
register_schema_models(console_ns, AccountResponse)
|
||||
|
||||
|
||||
def _serialize_account(account) -> dict:
|
||||
def _serialize_account(account) -> dict[str, Any]:
|
||||
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
integrate_fields = {
|
||||
"provider": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"is_bound": fields.Boolean,
|
||||
"link": fields.String,
|
||||
}
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
integrate_model = console_ns.model("AccountIntegrate", integrate_fields)
|
||||
integrate_list_model = console_ns.model(
|
||||
"AccountIntegrateList",
|
||||
{"data": fields.List(fields.Nested(integrate_model))},
|
||||
|
||||
class AccountIntegrateResponse(ResponseModel):
|
||||
provider: str
|
||||
created_at: int | None = None
|
||||
is_bound: bool
|
||||
link: str | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class AccountIntegrateListResponse(ResponseModel):
|
||||
data: list[AccountIntegrateResponse]
|
||||
|
||||
|
||||
class EducationVerifyResponse(ResponseModel):
|
||||
token: str | None = None
|
||||
|
||||
|
||||
class EducationStatusResponse(ResponseModel):
|
||||
result: bool | None = None
|
||||
is_student: bool | None = None
|
||||
expire_at: int | None = None
|
||||
allow_refresh: bool | None = None
|
||||
|
||||
@field_validator("expire_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_expire_at(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class EducationAutocompleteResponse(ResponseModel):
|
||||
data: list[str] = Field(default_factory=list)
|
||||
curr_page: int | None = None
|
||||
has_next: bool | None = None
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
AccountIntegrateResponse,
|
||||
AccountIntegrateListResponse,
|
||||
EducationVerifyResponse,
|
||||
EducationStatusResponse,
|
||||
EducationAutocompleteResponse,
|
||||
)
|
||||
|
||||
|
||||
@ -268,6 +315,18 @@ class AccountNameApi(Resource):
|
||||
|
||||
@console_ns.route("/account/avatar")
|
||||
class AccountAvatarApi(Resource):
|
||||
@console_ns.expect(console_ns.models[AccountAvatarQuery.__name__])
|
||||
@console_ns.doc("get_account_avatar")
|
||||
@console_ns.doc(description="Get account avatar url")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
avatar_url = file_helpers.get_signed_file_url(args.avatar)
|
||||
return {"avatar_url": avatar_url}
|
||||
|
||||
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -359,7 +418,7 @@ class AccountIntegrateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(integrate_list_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountIntegrateListResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@ -395,7 +454,9 @@ class AccountIntegrateApi(Resource):
|
||||
}
|
||||
)
|
||||
|
||||
return {"data": integrate_data}
|
||||
return AccountIntegrateListResponse(
|
||||
data=[AccountIntegrateResponse.model_validate(item) for item in integrate_data]
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/account/delete/verify")
|
||||
@ -447,31 +508,22 @@ class AccountDeleteUpdateFeedbackApi(Resource):
|
||||
|
||||
@console_ns.route("/account/education/verify")
|
||||
class EducationVerifyApi(Resource):
|
||||
verify_fields = {
|
||||
"token": fields.String,
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
@cloud_edition_billing_enabled
|
||||
@marshal_with(verify_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[EducationVerifyResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
return BillingService.EducationIdentity.verify(account.id, account.email)
|
||||
return EducationVerifyResponse.model_validate(
|
||||
BillingService.EducationIdentity.verify(account.id, account.email) or {}
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/account/education")
|
||||
class EducationApi(Resource):
|
||||
status_fields = {
|
||||
"result": fields.Boolean,
|
||||
"is_student": fields.Boolean,
|
||||
"expire_at": TimestampField,
|
||||
"allow_refresh": fields.Boolean,
|
||||
}
|
||||
|
||||
@console_ns.expect(console_ns.models[EducationActivatePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -491,37 +543,33 @@ class EducationApi(Resource):
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
@cloud_edition_billing_enabled
|
||||
@marshal_with(status_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[EducationStatusResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
res = BillingService.EducationIdentity.status(account.id)
|
||||
res = BillingService.EducationIdentity.status(account.id) or {}
|
||||
# convert expire_at to UTC timestamp from isoformat
|
||||
if res and "expire_at" in res:
|
||||
res["expire_at"] = datetime.fromisoformat(res["expire_at"]).astimezone(pytz.utc)
|
||||
return res
|
||||
return EducationStatusResponse.model_validate(res).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/account/education/autocomplete")
|
||||
class EducationAutoCompleteApi(Resource):
|
||||
data_fields = {
|
||||
"data": fields.List(fields.String),
|
||||
"curr_page": fields.Integer,
|
||||
"has_next": fields.Boolean,
|
||||
}
|
||||
|
||||
@console_ns.expect(console_ns.models[EducationAutocompleteQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
@cloud_edition_billing_enabled
|
||||
@marshal_with(data_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[EducationAutocompleteResponse.__name__])
|
||||
def get(self):
|
||||
payload = request.args.to_dict(flat=True)
|
||||
args = EducationAutocompleteQuery.model_validate(payload)
|
||||
|
||||
return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit)
|
||||
return EducationAutocompleteResponse.model_validate(
|
||||
BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit) or {}
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email")
|
||||
|
||||
@ -465,7 +465,7 @@ class ModelProviderModelDisableApi(Resource):
|
||||
class ParserValidate(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
credentials: dict
|
||||
credentials: dict[str, Any]
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource, fields, marshal
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
@ -26,6 +27,7 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
|
||||
@ -58,6 +60,37 @@ class WorkspaceInfoPayload(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class TenantInfoResponse(ResponseModel):
|
||||
id: str
|
||||
name: str | None = None
|
||||
plan: str | None = None
|
||||
status: str | None = None
|
||||
created_at: int | None = None
|
||||
role: str | None = None
|
||||
in_trial: bool | None = None
|
||||
trial_end_reason: str | None = None
|
||||
custom_config: dict | None = None
|
||||
trial_credits: int | None = None
|
||||
trial_credits_used: int | None = None
|
||||
next_credit_reset_date: int | None = None
|
||||
|
||||
@field_validator("plan", "status", "trial_end_reason", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum_like(cls, value):
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(getattr(value, "value", value))
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None):
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
@ -66,6 +99,7 @@ reg(WorkspaceListQuery)
|
||||
reg(SwitchWorkspacePayload)
|
||||
reg(WorkspaceCustomConfigPayload)
|
||||
reg(WorkspaceInfoPayload)
|
||||
reg(TenantInfoResponse)
|
||||
|
||||
provider_fields = {
|
||||
"provider_name": fields.String,
|
||||
@ -180,7 +214,7 @@ class TenantApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(tenant_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[TenantInfoResponse.__name__])
|
||||
def post(self):
|
||||
if request.path == "/info":
|
||||
logger.warning("Deprecated URL /info was used.")
|
||||
@ -200,7 +234,13 @@ class TenantApi(Resource):
|
||||
else:
|
||||
raise Unauthorized("workspace is archived")
|
||||
|
||||
return WorkspaceService.get_tenant_info(tenant), 200
|
||||
return (
|
||||
TenantInfoResponse.model_validate(
|
||||
WorkspaceService.get_tenant_info(tenant),
|
||||
from_attributes=True,
|
||||
).model_dump(mode="json"),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/switch")
|
||||
|
||||
@ -20,7 +20,7 @@ from models.account import AccountStatus
|
||||
from models.dataset import RateLimitLog
|
||||
from models.model import DifySetup
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
from services.operation_service import OperationService
|
||||
from services.operation_service import OperationService, UtmInfo
|
||||
|
||||
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
|
||||
|
||||
@ -205,7 +205,7 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
utm_info = request.cookies.get("utm_info")
|
||||
|
||||
if utm_info:
|
||||
utm_info_dict: dict = json.loads(utm_info)
|
||||
utm_info_dict: UtmInfo = json.loads(utm_info)
|
||||
OperationService.record_utm(current_tenant_id, utm_info_dict)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@ -9,7 +9,7 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.wraps import setup_required
|
||||
@ -56,7 +56,7 @@ class EnterpriseAppDSLImport(Resource):
|
||||
|
||||
account.set_tenant_id(workspace_id)
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
dsl_service = AppDslService(session)
|
||||
result = dsl_service.import_app(
|
||||
account=account,
|
||||
@ -65,6 +65,10 @@ class EnterpriseAppDSLImport(Resource):
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
)
|
||||
if result.status == ImportStatus.FAILED:
|
||||
session.rollback()
|
||||
else:
|
||||
session.commit()
|
||||
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
|
||||
@ -94,10 +94,9 @@ def get_user_tenant[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
|
||||
|
||||
def plugin_data[**P, R](
|
||||
view: Callable[P, R] | None = None,
|
||||
*,
|
||||
payload_type: type[BaseModel],
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
@ -116,7 +115,4 @@ def plugin_data[**P, R](
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
return decorator
|
||||
|
||||
@ -2,7 +2,7 @@ from typing import Any, Union
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource
|
||||
from graphon.variables.input_entities import VariableEntity
|
||||
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
@ -158,14 +158,20 @@ class MCPAppApi(Resource):
|
||||
except ValidationError as e:
|
||||
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
|
||||
|
||||
def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]:
|
||||
def _convert_user_input_form(self, raw_form: list[dict[str, Any]]) -> list[VariableEntity]:
|
||||
"""Convert raw user input form to VariableEntity objects"""
|
||||
return [self._create_variable_entity(item) for item in raw_form]
|
||||
|
||||
def _create_variable_entity(self, item: dict) -> VariableEntity:
|
||||
def _create_variable_entity(self, item: dict[str, Any]) -> VariableEntity:
|
||||
"""Create a single VariableEntity from raw form item"""
|
||||
variable_type = item.get("type", "") or list(item.keys())[0]
|
||||
variable = item[variable_type]
|
||||
variable_type_raw: str = item.get("type", "") or list(item.keys())[0]
|
||||
try:
|
||||
variable_type = VariableEntityType(variable_type_raw)
|
||||
except ValueError as e:
|
||||
raise MCPRequestError(
|
||||
mcp_types.INVALID_PARAMS, f"Invalid user_input_form variable type: {variable_type_raw}"
|
||||
) from e
|
||||
variable = item[variable_type_raw]
|
||||
|
||||
return VariableEntity(
|
||||
type=variable_type,
|
||||
@ -178,7 +184,7 @@ class MCPAppApi(Resource):
|
||||
json_schema=variable.get("json_schema"),
|
||||
)
|
||||
|
||||
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
|
||||
def _parse_mcp_request(self, args: dict[str, Any]) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
|
||||
"""Parse and validate MCP request"""
|
||||
try:
|
||||
return mcp_types.ClientRequest.model_validate(args)
|
||||
|
||||
@ -12,7 +12,12 @@ from controllers.service_api.wraps import validate_app_token
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import Annotation, AnnotationList
|
||||
from models.model import App
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.annotation_service import (
|
||||
AppAnnotationService,
|
||||
EnableAnnotationArgs,
|
||||
InsertAnnotationArgs,
|
||||
UpdateAnnotationArgs,
|
||||
)
|
||||
|
||||
|
||||
class AnnotationCreatePayload(BaseModel):
|
||||
@ -46,10 +51,15 @@ class AnnotationReplyActionApi(Resource):
|
||||
@validate_app_token
|
||||
def post(self, app_model: App, action: Literal["enable", "disable"]):
|
||||
"""Enable or disable annotation reply feature."""
|
||||
args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
payload = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {})
|
||||
match action:
|
||||
case "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
|
||||
enable_args: EnableAnnotationArgs = {
|
||||
"score_threshold": payload.score_threshold,
|
||||
"embedding_provider_name": payload.embedding_provider_name,
|
||||
"embedding_model_name": payload.embedding_model_name,
|
||||
}
|
||||
result = AppAnnotationService.enable_app_annotation(enable_args, app_model.id)
|
||||
case "disable":
|
||||
result = AppAnnotationService.disable_app_annotation(app_model.id)
|
||||
return result, 200
|
||||
@ -135,8 +145,9 @@ class AnnotationListApi(Resource):
|
||||
@validate_app_token
|
||||
def post(self, app_model: App):
|
||||
"""Create a new annotation."""
|
||||
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
|
||||
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
insert_args: InsertAnnotationArgs = {"question": payload.question, "answer": payload.answer}
|
||||
annotation = AppAnnotationService.insert_app_annotation_directly(insert_args, app_model.id)
|
||||
response = Annotation.model_validate(annotation, from_attributes=True)
|
||||
return response.model_dump(mode="json"), HTTPStatus.CREATED
|
||||
|
||||
@ -164,8 +175,9 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App, annotation_id: str):
|
||||
"""Update an existing annotation."""
|
||||
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
|
||||
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer}
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, annotation_id)
|
||||
response = Annotation.model_validate(annotation, from_attributes=True)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
|
||||
@ -3,10 +3,10 @@ import logging
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import TextToAudioPayload
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
@ -86,13 +86,6 @@ class AudioApi(Resource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = Field(default=None, description="Message ID")
|
||||
voice: str | None = Field(default=None, description="Voice to use for TTS")
|
||||
text: str | None = Field(default=None, description="Text to convert to audio")
|
||||
streaming: bool | None = Field(default=None, description="Enable streaming response")
|
||||
|
||||
|
||||
register_schema_model(service_api_ns, TextToAudioPayload)
|
||||
|
||||
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from graphon.variables.types import SegmentType
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
@ -14,14 +16,12 @@ from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields._value_type_serializer import serialize_value_type
|
||||
from fields.base import ResponseModel
|
||||
from fields.conversation_fields import (
|
||||
ConversationInfiniteScrollPagination,
|
||||
SimpleConversation,
|
||||
)
|
||||
from fields.conversation_variable_fields import (
|
||||
build_conversation_variable_infinite_scroll_pagination_model,
|
||||
build_conversation_variable_model,
|
||||
)
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
@ -70,12 +70,70 @@ class ConversationVariableUpdatePayload(BaseModel):
|
||||
value: Any
|
||||
|
||||
|
||||
class ConversationVariableResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
value_type: str
|
||||
value: str | None = None
|
||||
description: str | None = None
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
|
||||
@field_validator("value_type", mode="before")
|
||||
@classmethod
|
||||
def normalize_value_type(cls, value: Any) -> str:
|
||||
exposed_type = getattr(value, "exposed_type", None)
|
||||
if callable(exposed_type):
|
||||
return str(exposed_type().value)
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return str(SegmentType(value).exposed_type().value)
|
||||
except ValueError:
|
||||
return value
|
||||
try:
|
||||
return serialize_value_type(value)
|
||||
except (AttributeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
try:
|
||||
return serialize_value_type({"value_type": value})
|
||||
except (AttributeError, TypeError, ValueError):
|
||||
value_attr = getattr(value, "value", None)
|
||||
if value_attr is not None:
|
||||
return str(value_attr)
|
||||
return str(value)
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def normalize_value(cls, value: Any | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class ConversationVariableInfiniteScrollPaginationResponse(ResponseModel):
|
||||
limit: int
|
||||
has_more: bool
|
||||
data: list[ConversationVariableResponse]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
ConversationListQuery,
|
||||
ConversationRenamePayload,
|
||||
ConversationVariablesQuery,
|
||||
ConversationVariableUpdatePayload,
|
||||
ConversationVariableResponse,
|
||||
ConversationVariableInfiniteScrollPaginationResponse,
|
||||
)
|
||||
|
||||
|
||||
@ -204,8 +262,12 @@ class ConversationVariablesApi(Resource):
|
||||
404: "Conversation not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Variables retrieved successfully",
|
||||
service_api_ns.models[ConversationVariableInfiniteScrollPaginationResponse.__name__],
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
@service_api_ns.marshal_with(build_conversation_variable_infinite_scroll_pagination_model(service_api_ns))
|
||||
def get(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""List all variables for a conversation.
|
||||
|
||||
@ -222,9 +284,12 @@ class ConversationVariablesApi(Resource):
|
||||
last_id = str(query_args.last_id) if query_args.last_id else None
|
||||
|
||||
try:
|
||||
return ConversationService.get_conversational_variable(
|
||||
pagination = ConversationService.get_conversational_variable(
|
||||
app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name
|
||||
)
|
||||
return ConversationVariableInfiniteScrollPaginationResponse.model_validate(
|
||||
pagination, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
@ -243,8 +308,12 @@ class ConversationVariableDetailApi(Resource):
|
||||
404: "Conversation or variable not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Variable updated successfully",
|
||||
service_api_ns.models[ConversationVariableResponse.__name__],
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
@service_api_ns.marshal_with(build_conversation_variable_model(service_api_ns))
|
||||
def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
|
||||
"""Update a conversation variable's value.
|
||||
|
||||
@ -261,9 +330,10 @@ class ConversationVariableDetailApi(Resource):
|
||||
payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
return ConversationService.update_conversation_variable(
|
||||
variable = ConversationService.update_conversation_variable(
|
||||
app_model, conversation_id, variable_id, end_user, payload.value
|
||||
)
|
||||
return ConversationVariableResponse.model_validate(variable, from_attributes=True).model_dump(mode="json")
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationVariableNotExistsError:
|
||||
|
||||
@ -1,13 +1,15 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
from flask_restx import Resource, fields
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
@ -33,9 +35,10 @@ from core.errors.error import (
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||
from fields.base import ResponseModel
|
||||
from fields.end_user_fields import SimpleEndUser
|
||||
from fields.member_fields import SimpleAccount
|
||||
from libs import helper
|
||||
from libs.helper import OptionalTimestampField, TimestampField
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
@ -65,38 +68,142 @@ class WorkflowLogQuery(BaseModel):
|
||||
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
def _enum_value(value):
|
||||
return getattr(value, "value", value)
|
||||
|
||||
|
||||
class WorkflowRunStatusField(fields.Raw):
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
return obj.status.value
|
||||
return _enum_value(obj.status)
|
||||
|
||||
|
||||
class WorkflowRunOutputsField(fields.Raw):
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
if obj.status == WorkflowExecutionStatus.PAUSED:
|
||||
status = _enum_value(obj.status)
|
||||
if status == WorkflowExecutionStatus.PAUSED.value:
|
||||
return {}
|
||||
|
||||
outputs = obj.outputs_dict
|
||||
return outputs or {}
|
||||
|
||||
|
||||
workflow_run_fields = {
|
||||
"id": fields.String,
|
||||
"workflow_id": fields.String,
|
||||
"status": WorkflowRunStatusField,
|
||||
"inputs": fields.Raw,
|
||||
"outputs": WorkflowRunOutputsField,
|
||||
"error": fields.String,
|
||||
"total_steps": fields.Integer,
|
||||
"total_tokens": fields.Integer,
|
||||
"created_at": TimestampField,
|
||||
"finished_at": OptionalTimestampField,
|
||||
"elapsed_time": fields.Float,
|
||||
}
|
||||
class WorkflowRunResponse(ResponseModel):
|
||||
id: str
|
||||
workflow_id: str
|
||||
status: str
|
||||
inputs: dict | list | str | int | float | bool | None = None
|
||||
outputs: dict = Field(default_factory=dict)
|
||||
error: str | None = None
|
||||
total_steps: int | None = None
|
||||
total_tokens: int | None = None
|
||||
created_at: int | None = None
|
||||
finished_at: int | None = None
|
||||
elapsed_time: float | int | None = None
|
||||
|
||||
@field_validator("created_at", "finished_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
def build_workflow_run_model(api_or_ns: Namespace):
|
||||
"""Build the workflow run model for the API or Namespace."""
|
||||
return api_or_ns.model("WorkflowRun", workflow_run_fields)
|
||||
class WorkflowRunForLogResponse(ResponseModel):
|
||||
id: str
|
||||
version: str | None = None
|
||||
status: str | None = None
|
||||
triggered_from: str | None = None
|
||||
error: str | None = None
|
||||
elapsed_time: float | int | None = None
|
||||
total_tokens: int | None = None
|
||||
total_steps: int | None = None
|
||||
created_at: int | None = None
|
||||
finished_at: int | None = None
|
||||
exceptions_count: int | None = None
|
||||
|
||||
@field_validator("status", "triggered_from", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum(cls, value):
|
||||
return _enum_value(value)
|
||||
|
||||
@field_validator("created_at", "finished_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowAppLogPartialResponse(ResponseModel):
|
||||
id: str
|
||||
workflow_run: WorkflowRunForLogResponse | None = None
|
||||
details: dict | list | str | int | float | bool | None = None
|
||||
created_from: str | None = None
|
||||
created_by_role: str | None = None
|
||||
created_by_account: SimpleAccount | None = None
|
||||
created_by_end_user: SimpleEndUser | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_from", "created_by_role", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum(cls, value):
|
||||
return _enum_value(value)
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowAppLogPaginationResponse(ResponseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[WorkflowAppLogPartialResponse]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
WorkflowRunResponse,
|
||||
WorkflowRunForLogResponse,
|
||||
WorkflowAppLogPartialResponse,
|
||||
WorkflowAppLogPaginationResponse,
|
||||
)
|
||||
|
||||
|
||||
def _serialize_workflow_run(workflow_run: WorkflowRun) -> dict:
|
||||
status = _enum_value(workflow_run.status)
|
||||
raw_outputs = workflow_run.outputs_dict
|
||||
if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None:
|
||||
outputs: dict = {}
|
||||
elif isinstance(raw_outputs, dict):
|
||||
outputs = raw_outputs
|
||||
elif isinstance(raw_outputs, Mapping):
|
||||
outputs = dict(raw_outputs)
|
||||
else:
|
||||
outputs = {}
|
||||
return WorkflowRunResponse.model_validate(
|
||||
{
|
||||
"id": workflow_run.id,
|
||||
"workflow_id": workflow_run.workflow_id,
|
||||
"status": status,
|
||||
"inputs": workflow_run.inputs,
|
||||
"outputs": outputs,
|
||||
"error": workflow_run.error,
|
||||
"total_steps": workflow_run.total_steps,
|
||||
"total_tokens": workflow_run.total_tokens,
|
||||
"created_at": workflow_run.created_at,
|
||||
"finished_at": workflow_run.finished_at,
|
||||
"elapsed_time": workflow_run.elapsed_time,
|
||||
}
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
def _serialize_workflow_log_pagination(pagination) -> dict:
|
||||
return WorkflowAppLogPaginationResponse.model_validate(pagination, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@service_api_ns.route("/workflows/run/<string:workflow_run_id>")
|
||||
@ -112,7 +219,11 @@ class WorkflowRunDetailApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
@service_api_ns.marshal_with(build_workflow_run_model(service_api_ns))
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Workflow run details retrieved successfully",
|
||||
service_api_ns.models[WorkflowRunResponse.__name__],
|
||||
)
|
||||
def get(self, app_model: App, workflow_run_id: str):
|
||||
"""Get a workflow task running detail.
|
||||
|
||||
@ -133,7 +244,7 @@ class WorkflowRunDetailApi(Resource):
|
||||
)
|
||||
if not workflow_run:
|
||||
raise NotFound("Workflow run not found.")
|
||||
return workflow_run
|
||||
return _serialize_workflow_run(workflow_run)
|
||||
|
||||
|
||||
@service_api_ns.route("/workflows/run")
|
||||
@ -299,7 +410,11 @@ class WorkflowAppLogApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
@service_api_ns.marshal_with(build_workflow_app_log_pagination_model(service_api_ns))
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Logs retrieved successfully",
|
||||
service_api_ns.models[WorkflowAppLogPaginationResponse.__name__],
|
||||
)
|
||||
def get(self, app_model: App):
|
||||
"""Get workflow app logs.
|
||||
|
||||
@ -327,4 +442,4 @@ class WorkflowAppLogApi(Resource):
|
||||
created_by_account=args.created_by_account,
|
||||
)
|
||||
|
||||
return workflow_app_log_pagination
|
||||
return _serialize_workflow_log_pagination(workflow_app_log_pagination)
|
||||
|
||||
@ -10,6 +10,7 @@ from sqlalchemy import desc, func, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload
|
||||
from controllers.common.errors import (
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
@ -100,15 +101,6 @@ class DocumentListQuery(BaseModel):
|
||||
status: str | None = Field(default=None, description="Document status filter")
|
||||
|
||||
|
||||
DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
|
||||
|
||||
|
||||
class DocumentBatchDownloadZipPayload(BaseModel):
|
||||
"""Request payload for bulk downloading uploaded documents as a ZIP archive."""
|
||||
|
||||
document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
|
||||
|
||||
|
||||
register_enum_models(service_api_ns, RetrievalMethod)
|
||||
|
||||
register_schema_models(
|
||||
|
||||
@ -2,9 +2,9 @@ from typing import Literal
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MetadataUpdatePayload
|
||||
from controllers.common.schema import register_schema_model, register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
|
||||
@ -18,11 +18,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
)
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
class MetadataUpdatePayload(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
register_schema_model(service_api_ns, MetadataUpdatePayload)
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
|
||||
@ -8,6 +8,7 @@ from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
@ -32,25 +33,25 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
|
||||
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
|
||||
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict[str, Any]:
|
||||
"""Marshal a single segment and enrich it with summary content."""
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
|
||||
segment_dict["summary"] = summary.summary_content if summary else None
|
||||
return segment_dict
|
||||
|
||||
|
||||
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
|
||||
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict[str, Any]]:
|
||||
"""Marshal multiple segments and enrich them with summary content (batch query)."""
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
summaries: dict = {}
|
||||
summaries: dict[str, str | None] = {}
|
||||
if segment_ids:
|
||||
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
|
||||
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
|
||||
|
||||
result = []
|
||||
result: list[dict[str, Any]] = []
|
||||
for segment in segments:
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
segment_dict["summary"] = summaries.get(segment.id)
|
||||
result.append(segment_dict)
|
||||
return result
|
||||
@ -69,20 +70,12 @@ class SegmentUpdatePayload(BaseModel):
|
||||
segment: SegmentUpdateArgs
|
||||
|
||||
|
||||
class ChildChunkCreatePayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class ChildChunkListQuery(BaseModel):
|
||||
limit: int = Field(default=20, ge=1)
|
||||
keyword: str | None = None
|
||||
page: int = Field(default=1, ge=1)
|
||||
|
||||
|
||||
class ChildChunkUpdatePayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
SegmentCreatePayload,
|
||||
|
||||
@ -5,6 +5,7 @@ Web App Human Input Form APIs.
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
@ -58,10 +59,19 @@ def _to_timestamp(value: datetime) -> int:
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
class FormDefinitionPayload(TypedDict):
|
||||
form_content: Any
|
||||
inputs: Any
|
||||
resolved_default_values: dict[str, str]
|
||||
user_actions: Any
|
||||
expiration_time: int
|
||||
site: NotRequired[dict]
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
|
||||
"""Return the form payload (optionally with site) as a JSON response."""
|
||||
definition_payload = form.get_definition().model_dump()
|
||||
payload = {
|
||||
payload: FormDefinitionPayload = {
|
||||
"form_content": definition_payload["rendered_content"],
|
||||
"inputs": definition_payload["inputs"],
|
||||
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
|
||||
@ -92,7 +102,7 @@ class HumanInputFormApi(Resource):
|
||||
_FORM_ACCESS_RATE_LIMITER.increment_rate_limit(ip_address)
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
# TODO(QuantumGhost): forbid submision for form tokens
|
||||
# TODO(QuantumGhost): forbid submission for form tokens
|
||||
# that are only for console.
|
||||
form = service.get_form_by_token(form_token)
|
||||
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
import logging
|
||||
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource
|
||||
from jwt import InvalidTokenError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
@ -20,7 +23,7 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.wraps import decode_jwt_token
|
||||
from libs.helper import EmailStr
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.passport import PassportService
|
||||
from libs.password import valid_password
|
||||
from libs.token import (
|
||||
@ -29,9 +32,11 @@ from libs.token import (
|
||||
)
|
||||
from services.account_service import AccountService
|
||||
from services.app_service import AppService
|
||||
from services.entities.auth_entities import LoginPayloadBase
|
||||
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoginPayload(LoginPayloadBase):
|
||||
@field_validator("password")
|
||||
@ -76,14 +81,18 @@ class LoginApi(Resource):
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
payload = LoginPayload.model_validate(web_ns.payload or {})
|
||||
normalized_email = payload.email.lower()
|
||||
|
||||
try:
|
||||
account = WebAppAuthService.authenticate(payload.email, payload.password)
|
||||
except services.errors.account.AccountLoginError:
|
||||
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError:
|
||||
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
|
||||
raise AuthenticationFailedError()
|
||||
except services.errors.account.AccountNotFoundError:
|
||||
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND)
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
token = WebAppAuthService.login(account=account)
|
||||
@ -212,21 +221,30 @@ class EmailCodeLoginApi(Resource):
|
||||
|
||||
token_data = WebAppAuthService.get_email_code_login_data(payload.token)
|
||||
if token_data is None:
|
||||
_log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
|
||||
raise InvalidTokenError()
|
||||
|
||||
token_email = token_data.get("email")
|
||||
if not isinstance(token_email, str):
|
||||
_log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
|
||||
raise InvalidEmailError()
|
||||
normalized_token_email = token_email.lower()
|
||||
if normalized_token_email != user_email:
|
||||
_log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != payload.code:
|
||||
_log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
|
||||
raise EmailCodeError()
|
||||
|
||||
WebAppAuthService.revoke_email_code_login_token(payload.token)
|
||||
account = WebAppAuthService.get_user_through_email(token_email)
|
||||
try:
|
||||
account = WebAppAuthService.get_user_through_email(token_email)
|
||||
except Unauthorized as exc:
|
||||
_log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
|
||||
raise AccountBannedError() from exc
|
||||
if not account:
|
||||
_log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND)
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
token = WebAppAuthService.login(account=account)
|
||||
@ -234,3 +252,12 @@ class EmailCodeLoginApi(Resource):
|
||||
response = make_response({"result": "success", "data": {"access_token": token}})
|
||||
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
|
||||
return response
|
||||
|
||||
|
||||
def _log_web_login_failure(*, email: str, reason: LoginFailureReason) -> None:
|
||||
logger.warning(
|
||||
"Web login failed: email=%s reason=%s ip_address=%s",
|
||||
email,
|
||||
reason,
|
||||
extract_remote_ip(request),
|
||||
)
|
||||
|
||||
@ -3,10 +3,10 @@ from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MessageFeedbackPayload
|
||||
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
@ -25,7 +25,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
@ -41,19 +40,6 @@ from services.message_service import MessageService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: str = Field(description="Conversation UUID")
|
||||
first_id: str | None = Field(default=None, description="First message ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
|
||||
|
||||
@field_validator("conversation_id", "first_id")
|
||||
@classmethod
|
||||
def validate_uuid(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class MessageMoreLikeThisQuery(BaseModel):
|
||||
response_mode: Literal["blocking", "streaming"] = Field(
|
||||
description="Response mode",
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource
|
||||
@ -103,21 +104,23 @@ class PassportResource(Resource):
|
||||
return response
|
||||
|
||||
|
||||
def decode_enterprise_webapp_user_id(jwt_token: str | None):
|
||||
def decode_enterprise_webapp_user_id(jwt_token: str | None) -> dict[str, Any] | None:
|
||||
"""
|
||||
Decode the enterprise user session from the Authorization header.
|
||||
"""
|
||||
if not jwt_token:
|
||||
return None
|
||||
|
||||
decoded = PassportService().verify(jwt_token)
|
||||
decoded: dict[str, Any] = PassportService().verify(jwt_token)
|
||||
source = decoded.get("token_source")
|
||||
if not source or source != "webapp_login_token":
|
||||
raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
|
||||
return decoded
|
||||
|
||||
|
||||
def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType):
|
||||
def exchange_token_for_existing_web_user(
|
||||
app_code: str, enterprise_user_decoded: dict[str, Any], auth_type: WebAppAuthType
|
||||
):
|
||||
"""
|
||||
Exchange a token for an existing web user session.
|
||||
"""
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
from flask_restx import fields, marshal, marshal_with
|
||||
from sqlalchemy import select
|
||||
@ -113,12 +113,12 @@ class AppSiteInfo:
|
||||
}
|
||||
|
||||
|
||||
def serialize_site(site: Site) -> dict:
|
||||
def serialize_site(site: Site) -> dict[str, Any]:
|
||||
"""Serialize Site model using the same schema as AppSiteApi."""
|
||||
return cast(dict, marshal(site, AppSiteApi.site_fields))
|
||||
return cast(dict[str, Any], marshal(site, AppSiteApi.site_fields))
|
||||
|
||||
|
||||
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict:
|
||||
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict[str, Any]:
|
||||
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
|
||||
app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
|
||||
return cast(dict, marshal(app_site_info, AppSiteApi.app_fields))
|
||||
return cast(dict[str, Any], marshal(app_site_info, AppSiteApi.app_fields))
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
from typing import Any, Union
|
||||
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResultChunk
|
||||
|
||||
@ -11,7 +11,7 @@ from core.agent.entities import AgentScratchpadUnit
|
||||
class CotAgentOutputParser:
|
||||
@classmethod
|
||||
def handle_react_stream_output(
|
||||
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
|
||||
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict[str, Any]
|
||||
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
|
||||
def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]:
|
||||
action_name = None
|
||||
|
||||
@ -84,7 +84,7 @@ class AgentStrategyEntity(BaseModel):
|
||||
identity: AgentStrategyIdentity
|
||||
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
|
||||
description: I18nObject = Field(..., description="The description of the agent strategy")
|
||||
output_schema: dict | None = None
|
||||
output_schema: dict[str, Any] | None = None
|
||||
features: list[AgentFeature] | None = None
|
||||
meta_version: str | None = None
|
||||
# pydantic configs
|
||||
|
||||
@ -22,8 +22,8 @@ class SensitiveWordAvoidanceConfigManager:
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(
|
||||
cls, tenant_id: str, config: dict, only_structure_validate: bool = False
|
||||
) -> tuple[dict, list[str]]:
|
||||
cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False
|
||||
) -> tuple[dict[str, Any], list[str]]:
|
||||
if not config.get("sensitive_word_avoidance"):
|
||||
config["sensitive_word_avoidance"] = {"enabled": False}
|
||||
|
||||
|
||||
@ -138,7 +138,9 @@ class DatasetConfigManager:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(
|
||||
cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]
|
||||
) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for dataset feature
|
||||
|
||||
@ -172,7 +174,7 @@ class DatasetConfigManager:
|
||||
return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]
|
||||
|
||||
@classmethod
|
||||
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict):
|
||||
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]):
|
||||
"""
|
||||
Extract dataset config for legacy compatibility
|
||||
|
||||
|
||||
@ -41,7 +41,7 @@ class ModelConfigManager:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for model config
|
||||
|
||||
@ -108,7 +108,7 @@ class ModelConfigManager:
|
||||
return dict(config), ["model"]
|
||||
|
||||
@classmethod
|
||||
def validate_model_completion_params(cls, cp: dict):
|
||||
def validate_model_completion_params(cls, cp: dict[str, Any]):
|
||||
# model.completion_params
|
||||
if not isinstance(cp, dict):
|
||||
raise ValueError("model.completion_params must be of object type")
|
||||
|
||||
@ -65,7 +65,7 @@ class PromptTemplateConfigManager:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate pre_prompt and set defaults for prompt feature
|
||||
depending on the config['model']
|
||||
@ -130,7 +130,7 @@ class PromptTemplateConfigManager:
|
||||
return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]
|
||||
|
||||
@classmethod
|
||||
def validate_post_prompt_and_set_defaults(cls, config: dict):
|
||||
def validate_post_prompt_and_set_defaults(cls, config: dict[str, Any]):
|
||||
"""
|
||||
Validate post_prompt and set defaults for prompt feature
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import re
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||
|
||||
@ -82,7 +82,7 @@ class BasicVariablesConfigManager:
|
||||
return variable_entities, external_data_variables
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, tenant_id: str, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for user input form
|
||||
|
||||
@ -99,7 +99,7 @@ class BasicVariablesConfigManager:
|
||||
return config, related_config_keys
|
||||
|
||||
@classmethod
|
||||
def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_variables_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for user input form
|
||||
|
||||
@ -164,7 +164,9 @@ class BasicVariablesConfigManager:
|
||||
return config, ["user_input_form"]
|
||||
|
||||
@classmethod
|
||||
def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_external_data_tools_and_set_defaults(
|
||||
cls, tenant_id: str, config: dict[str, Any]
|
||||
) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for external data fetch feature
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ class FileUploadConfigManager:
|
||||
return FileUploadConfig.model_validate(file_upload_dict)
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for file upload feature
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
|
||||
|
||||
@ -13,7 +15,7 @@ class AppConfigModel(BaseModel):
|
||||
|
||||
class MoreLikeThisConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> bool:
|
||||
def convert(cls, config: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@ -23,7 +25,7 @@ class MoreLikeThisConfigManager:
|
||||
return AppConfigModel.model_validate(validated_config).more_like_this.enabled
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
try:
|
||||
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
|
||||
except ValidationError:
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class OpeningStatementConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> tuple[str, list]:
|
||||
def convert(cls, config: dict[str, Any]) -> tuple[str, list[str]]:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@ -15,7 +18,7 @@ class OpeningStatementConfigManager:
|
||||
return opening_statement, suggested_questions_list
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for opening statement feature
|
||||
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class RetrievalResourceConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> bool:
|
||||
def convert(cls, config: dict[str, Any]) -> bool:
|
||||
show_retrieve_source = False
|
||||
retriever_resource_dict = config.get("retriever_resource")
|
||||
if retriever_resource_dict:
|
||||
@ -10,7 +13,7 @@ class RetrievalResourceConfigManager:
|
||||
return show_retrieve_source
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for retriever resource feature
|
||||
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class SpeechToTextConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> bool:
|
||||
def convert(cls, config: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@ -15,7 +18,7 @@ class SpeechToTextConfigManager:
|
||||
return speech_to_text
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for speech to text feature
|
||||
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class SuggestedQuestionsAfterAnswerConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> bool:
|
||||
def convert(cls, config: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@ -15,7 +18,7 @@ class SuggestedQuestionsAfterAnswerConfigManager:
|
||||
return suggested_questions_after_answer
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for suggested questions feature
|
||||
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import TextToSpeechEntity
|
||||
|
||||
|
||||
class TextToSpeechConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict):
|
||||
def convert(cls, config: dict[str, Any]):
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@ -22,7 +24,7 @@ class TextToSpeechConfigManager:
|
||||
return text_to_speech
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for text to speech feature
|
||||
|
||||
|
||||
@ -57,7 +57,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, Any, None]:
|
||||
) -> Generator[dict[str, Any] | str, Any, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@ -88,7 +88,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, Any, None]:
|
||||
) -> Generator[dict[str, Any] | str, Any, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
@ -56,7 +56,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@ -87,7 +87,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
|
||||
@ -24,7 +24,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
return cls.convert_blocking_full_response(response)
|
||||
else:
|
||||
|
||||
def _generate_full_response() -> Generator[dict | str, Any, None]:
|
||||
def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]:
|
||||
yield from cls.convert_stream_full_response(response)
|
||||
|
||||
return _generate_full_response()
|
||||
@ -33,7 +33,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
return cls.convert_blocking_simple_response(response)
|
||||
else:
|
||||
|
||||
def _generate_simple_response() -> Generator[dict | str, Any, None]:
|
||||
def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]:
|
||||
yield from cls.convert_stream_simple_response(response)
|
||||
|
||||
return _generate_simple_response()
|
||||
@ -52,14 +52,14 @@ class AppGenerateResponseConverter(ABC):
|
||||
@abstractmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
@ -56,7 +56,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@ -87,7 +87,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
@ -55,7 +55,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@ -85,7 +85,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
@ -17,7 +17,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = WorkflowAppBlockingResponse
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
@ -26,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return dict(blocking_response.model_dump())
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
@ -37,7 +37,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@ -66,7 +66,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user