mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 18:15:58 +08:00
Compare commits
130 Commits
feat/new-a
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| dbceb3067e | |||
| 425457cb16 | |||
| e5bd18132c | |||
| 2f33867d07 | |||
| fd71c56f16 | |||
| e3c2116501 | |||
| fb17339d89 | |||
| 9fd196642d | |||
| 98897a5379 | |||
| 5542329554 | |||
| 79332c0e5e | |||
| 50a55513d4 | |||
| 3bccdd6c9a | |||
| 76af80e332 | |||
| 7a880ae60c | |||
| 5bc0f9513b | |||
| b77801ece9 | |||
| 7de92c598f | |||
| 693080aa12 | |||
| 25c388d0db | |||
| b1722c8af9 | |||
| b65a5fcd97 | |||
| 1c3cba281a | |||
| 800954f8ce | |||
| f66a3c49c4 | |||
| ef396ac84e | |||
| 7e7b27fdec | |||
| 9c90c1c455 | |||
| b1df52b8ff | |||
| e527b7c5f1 | |||
| 149b9d4c0f | |||
| ef28a63ad3 | |||
| e78558bc06 | |||
| f63d7c4121 | |||
| ef062fb397 | |||
| a2ea7ca039 | |||
| 6876cd787b | |||
| 50a6892c3a | |||
| 1bcc7f78c7 | |||
| 2fd5b76ac1 | |||
| 62f42b3f24 | |||
| 2c58b424a1 | |||
| 381c518b23 | |||
| ebf741114d | |||
| 648dde5e96 | |||
| a3042e6332 | |||
| e5fd3133f4 | |||
| e1bbe57f9c | |||
| d4783e8c14 | |||
| 736880e046 | |||
| bd7a9b5fcf | |||
| 9a47bb2f80 | |||
| d7ad2baf79 | |||
| a951cc996b | |||
| 173e0d6f35 | |||
| 62bb830338 | |||
| f7c6270f74 | |||
| 711fe6ba2c | |||
| fbedb60371 | |||
| 974d2f1627 | |||
| ed401728eb | |||
| fc389a54c5 | |||
| c8b372dba0 | |||
| 2333d75c56 | |||
| 2ef9a8a769 | |||
| 21ab9b9d8c | |||
| 79c1473378 | |||
| 93b8a74351 | |||
| 28185170b0 | |||
| 178883b4cc | |||
| e9f9041b25 | |||
| 175290fa04 | |||
| b0c4d8c541 | |||
| 0f643bca76 | |||
| eeebedcfe8 | |||
| 2f682780fa | |||
| ed83f5369e | |||
| 4ee1bd5f32 | |||
| 1c2bbed405 | |||
| d573fc0e65 | |||
| f8b249e649 | |||
| fbcab757d5 | |||
| c0e998ef6e | |||
| 84f25807db | |||
| 83b242be7b | |||
| a12d740a5d | |||
| 3bbb014dc7 | |||
| f040733e28 | |||
| b0bf7ca486 | |||
| 14d83c8bac | |||
| 8b506dfa42 | |||
| ac2258c2dc | |||
| 3c279edcf2 | |||
| 9ed8a5ed73 | |||
| 3d4ddf4a6f | |||
| 4e0273bb28 | |||
| 7056d2ae99 | |||
| d8fbc00cb9 | |||
| 57c5f0ec87 | |||
| e5bd80c719 | |||
| 25a33a454c | |||
| bd30784b1d | |||
| 28fce0a890 | |||
| e1eb582bea | |||
| 2042ee453b | |||
| 33c4e512f1 | |||
| 253e8a3f98 | |||
| 06b63d65d1 | |||
| 08f3133414 | |||
| d412cddf39 | |||
| 671c5cdd84 | |||
| 554f060092 | |||
| e243e8d8a3 | |||
| 1b935a367f | |||
| 2edd083a71 | |||
| dd50a68bf2 | |||
| e8dd3461e8 | |||
| 8dd4473432 | |||
| b5bbbdd840 | |||
| f0266e13c5 | |||
| ae898652b2 | |||
| c34f67495c | |||
| 815c536e05 | |||
| fc64427ae1 | |||
| 11c518478e | |||
| e823635ce1 | |||
| 98e74c8fde | |||
| 29bfa33d59 | |||
| 3ead0beeb1 | |||
| 2108c44c8b |
79
.agents/skills/e2e-cucumber-playwright/SKILL.md
Normal file
79
.agents/skills/e2e-cucumber-playwright/SKILL.md
Normal file
@ -0,0 +1,79 @@
|
||||
---
|
||||
name: e2e-cucumber-playwright
|
||||
description: Write, update, or review Dify end-to-end tests under `e2e/` that use Cucumber, Gherkin, and Playwright. Use when the task involves `.feature` files, `features/step-definitions/`, `features/support/`, `DifyWorld`, scenario tags, locator/assertion choices, or E2E testing best practices for this repository.
|
||||
---
|
||||
|
||||
# Dify E2E Cucumber + Playwright
|
||||
|
||||
Use this skill for Dify's repository-level E2E suite in `e2e/`. Use [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) as the canonical guide for local architecture and conventions, then apply Playwright/Cucumber best practices only where they fit the current suite.
|
||||
|
||||
## Scope
|
||||
|
||||
- Use this skill for `.feature` files, Cucumber step definitions, `DifyWorld`, hooks, tags, and E2E review work under `e2e/`.
|
||||
- Do not use this skill for Vitest or React Testing Library work under `web/`; use `frontend-testing` instead.
|
||||
- Do not use this skill for backend test or API review tasks under `api/`.
|
||||
|
||||
## Read Order
|
||||
|
||||
1. Read [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) first.
|
||||
2. Read only the files directly involved in the task:
|
||||
- target `.feature` files under `e2e/features/`
|
||||
- related step files under `e2e/features/step-definitions/`
|
||||
- `e2e/features/support/hooks.ts` and `e2e/features/support/world.ts` when session lifecycle or shared state matters
|
||||
- `e2e/scripts/run-cucumber.ts` and `e2e/cucumber.config.ts` when tags or execution flow matter
|
||||
3. Read [`references/playwright-best-practices.md`](references/playwright-best-practices.md) only when locator, assertion, isolation, or waiting choices are involved.
|
||||
4. Read [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md) only when scenario wording, step granularity, tags, or expression design are involved.
|
||||
5. Re-check official docs with Context7 before introducing a new Playwright or Cucumber pattern.
|
||||
|
||||
## Local Rules
|
||||
|
||||
- `e2e/` uses Cucumber for scenarios and Playwright as the browser layer.
|
||||
- `DifyWorld` is the per-scenario context object. Type `this` as `DifyWorld` and use `async function`, not arrow functions.
|
||||
- Keep glue organized by capability under `e2e/features/step-definitions/`; use `common/` only for broadly reusable steps.
|
||||
- Browser session behavior comes from `features/support/hooks.ts`:
|
||||
- default: authenticated session with shared storage state
|
||||
- `@unauthenticated`: clean browser context
|
||||
- `@authenticated`: readability/selective-run tag only unless implementation changes
|
||||
- `@fresh`: only for `e2e:full*` flows
|
||||
- Do not import Playwright Test runner patterns that bypass the current Cucumber + `DifyWorld` architecture unless the task is explicitly about changing that architecture.
|
||||
|
||||
## Workflow
|
||||
|
||||
1. Rebuild local context.
|
||||
- Inspect the target feature area.
|
||||
- Reuse an existing step when wording and behavior already match.
|
||||
- Add a new step only for a genuinely new user action or assertion.
|
||||
- Keep edits close to the current capability folder unless the step is broadly reusable.
|
||||
2. Write behavior-first scenarios.
|
||||
- Describe user-observable behavior, not DOM mechanics.
|
||||
- Keep each scenario focused on one workflow or outcome.
|
||||
- Keep scenarios independent and re-runnable.
|
||||
3. Write step definitions in the local style.
|
||||
- Keep one step to one user-visible action or one assertion.
|
||||
- Prefer Cucumber Expressions such as `{string}` and `{int}`.
|
||||
- Scope locators to stable containers when the page has repeated elements.
|
||||
- Avoid page-object layers or extra helper abstractions unless repeated complexity clearly justifies them.
|
||||
4. Use Playwright in the local style.
|
||||
- Prefer user-facing locators: `getByRole`, `getByLabel`, `getByPlaceholder`, `getByText`, then `getByTestId` for explicit contracts.
|
||||
- Use web-first `expect(...)` assertions.
|
||||
- Do not use `waitForTimeout`, manual polling, or raw visibility checks when a locator action or retrying assertion already expresses the behavior.
|
||||
5. Validate narrowly.
|
||||
- Run the narrowest tagged scenario or flow that exercises the change.
|
||||
- Run `pnpm -C e2e check`.
|
||||
- Broaden verification only when the change affects hooks, tags, setup, or shared step semantics.
|
||||
|
||||
## Review Checklist
|
||||
|
||||
- Does the scenario describe behavior rather than implementation?
|
||||
- Does it fit the current session model, tags, and `DifyWorld` usage?
|
||||
- Should an existing step be reused instead of adding a new one?
|
||||
- Are locators user-facing and assertions web-first?
|
||||
- Does the change introduce hidden coupling across scenarios, tags, or instance state?
|
||||
- Does it document or implement behavior that differs from the real hooks or configuration?
|
||||
|
||||
Lead findings with correctness, flake risk, and architecture drift.
|
||||
|
||||
## References
|
||||
|
||||
- [`references/playwright-best-practices.md`](references/playwright-best-practices.md)
|
||||
- [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md)
|
||||
@ -0,0 +1,4 @@
|
||||
interface:
|
||||
display_name: "E2E Cucumber + Playwright"
|
||||
short_description: "Write and review Dify E2E scenarios."
|
||||
default_prompt: "Use $e2e-cucumber-playwright to write or review a Dify E2E scenario under e2e/."
|
||||
@ -0,0 +1,93 @@
|
||||
# Cucumber Best Practices For Dify E2E
|
||||
|
||||
Use this reference when writing or reviewing Gherkin scenarios, step definitions, parameter expressions, and step reuse in Dify's `e2e/` suite.
|
||||
|
||||
Official sources:
|
||||
|
||||
- https://cucumber.io/docs/guides/10-minute-tutorial/
|
||||
- https://cucumber.io/docs/cucumber/step-definitions/
|
||||
- https://cucumber.io/docs/cucumber/cucumber-expressions/
|
||||
|
||||
## What Matters Most
|
||||
|
||||
### 1. Treat scenarios as executable specifications
|
||||
|
||||
Cucumber scenarios should describe examples of behavior, not test implementation recipes.
|
||||
|
||||
Apply it like this:
|
||||
|
||||
- write what the user does and what should happen
|
||||
- avoid UI-internal wording such as selector details, DOM structure, or component names
|
||||
- keep language concrete enough that the scenario reads like living documentation
|
||||
|
||||
### 2. Keep scenarios focused
|
||||
|
||||
A scenario should usually prove one workflow or business outcome. If a scenario wanders across several unrelated behaviors, split it.
|
||||
|
||||
In Dify's suite, this means:
|
||||
|
||||
- one capability-focused scenario per feature path
|
||||
- no long setup chains when existing bootstrap or reusable steps already cover them
|
||||
- no hidden dependency on another scenario's side effects
|
||||
|
||||
### 3. Reuse steps, but only when behavior really matches
|
||||
|
||||
Good reuse reduces duplication. Bad reuse hides meaning.
|
||||
|
||||
Prefer reuse when:
|
||||
|
||||
- the user action is genuinely the same
|
||||
- the expected outcome is genuinely the same
|
||||
- the wording stays natural across features
|
||||
|
||||
Write a new step when:
|
||||
|
||||
- the behavior is materially different
|
||||
- reusing the old wording would make the scenario misleading
|
||||
- a supposedly generic step would become an implementation-detail wrapper
|
||||
|
||||
### 4. Prefer Cucumber Expressions
|
||||
|
||||
Use Cucumber Expressions for parameters unless regex is clearly necessary.
|
||||
|
||||
Common examples:
|
||||
|
||||
- `{string}` for labels, names, and visible text
|
||||
- `{int}` for counts
|
||||
- `{float}` for decimal values
|
||||
- `{word}` only when the value is truly a single token
|
||||
|
||||
Keep expressions readable. If a step needs complicated parsing logic, first ask whether the scenario wording should be simpler.
|
||||
|
||||
### 5. Keep step definitions thin and meaningful
|
||||
|
||||
Step definitions are glue between Gherkin and automation, not a second abstraction language.
|
||||
|
||||
For Dify:
|
||||
|
||||
- type `this` as `DifyWorld`
|
||||
- use `async function`
|
||||
- keep each step to one user-visible action or assertion
|
||||
- rely on `DifyWorld` and existing support code for shared context
|
||||
- avoid leaking cross-scenario state
|
||||
|
||||
### 6. Use tags intentionally
|
||||
|
||||
Tags should communicate run scope or session semantics, not become ad hoc metadata.
|
||||
|
||||
In Dify's current suite:
|
||||
|
||||
- capability tags group related scenarios
|
||||
- `@unauthenticated` changes session behavior
|
||||
- `@authenticated` is descriptive/selective, not a behavior switch by itself
|
||||
- `@fresh` belongs to reset/full-install flows only
|
||||
|
||||
If a proposed tag implies behavior, verify that hooks or runner configuration actually implement it.
|
||||
|
||||
## Review Questions
|
||||
|
||||
- Does the scenario read like a real example of product behavior?
|
||||
- Are the steps behavior-oriented instead of implementation-oriented?
|
||||
- Is a reused step still truthful in this feature?
|
||||
- Is a new tag documenting real behavior, or inventing semantics that the suite does not implement?
|
||||
- Would a new reader understand the outcome without opening the step-definition file?
|
||||
@ -0,0 +1,96 @@
|
||||
# Playwright Best Practices For Dify E2E
|
||||
|
||||
Use this reference when writing or reviewing locator, assertion, isolation, or synchronization logic for Dify's Cucumber-based E2E suite.
|
||||
|
||||
Official sources:
|
||||
|
||||
- https://playwright.dev/docs/best-practices
|
||||
- https://playwright.dev/docs/locators
|
||||
- https://playwright.dev/docs/test-assertions
|
||||
- https://playwright.dev/docs/browser-contexts
|
||||
|
||||
## What Matters Most
|
||||
|
||||
### 1. Keep scenarios isolated
|
||||
|
||||
Playwright's model is built around clean browser contexts so one test does not leak into another. In Dify's suite, that principle maps to per-scenario session setup in `features/support/hooks.ts` and `DifyWorld`.
|
||||
|
||||
Apply it like this:
|
||||
|
||||
- do not depend on another scenario having run first
|
||||
- do not persist ad hoc scenario state outside `DifyWorld`
|
||||
- do not couple ordinary scenarios to `@fresh` behavior
|
||||
- when a flow needs special auth/session semantics, express that through the existing tag model or explicit hook changes
|
||||
|
||||
### 2. Prefer user-facing locators
|
||||
|
||||
Playwright recommends built-in locators that reflect what users perceive on the page.
|
||||
|
||||
Preferred order in this repository:
|
||||
|
||||
1. `getByRole`
|
||||
2. `getByLabel`
|
||||
3. `getByPlaceholder`
|
||||
4. `getByText`
|
||||
5. `getByTestId` when an explicit test contract is the most stable option
|
||||
|
||||
Avoid raw CSS/XPath selectors unless no stable user-facing contract exists and adding one is not practical.
|
||||
|
||||
Also remember:
|
||||
|
||||
- repeated content usually needs scoping to a stable container
|
||||
- exact text matching is often too brittle when role/name or label already exists
|
||||
- `getByTestId` is acceptable when semantics are weak but the contract is intentional
|
||||
|
||||
### 3. Use web-first assertions
|
||||
|
||||
Playwright assertions auto-wait and retry. Prefer them over manual state inspection.
|
||||
|
||||
Prefer:
|
||||
|
||||
- `await expect(page).toHaveURL(...)`
|
||||
- `await expect(locator).toBeVisible()`
|
||||
- `await expect(locator).toBeHidden()`
|
||||
- `await expect(locator).toBeEnabled()`
|
||||
- `await expect(locator).toHaveText(...)`
|
||||
|
||||
Avoid:
|
||||
|
||||
- `expect(await locator.isVisible()).toBe(true)`
|
||||
- custom polling loops for DOM state
|
||||
- `waitForTimeout` as synchronization
|
||||
|
||||
If a condition genuinely needs custom retry logic, use Playwright's polling/assertion tools deliberately and keep that choice local and explicit.
|
||||
|
||||
### 4. Let actions wait for actionability
|
||||
|
||||
Locator actions already wait for the element to be actionable. Do not preface every click/fill with extra timing logic unless the action needs a specific visible/ready assertion for clarity.
|
||||
|
||||
Good pattern:
|
||||
|
||||
- assert a meaningful visible state when that is part of the behavior
|
||||
- then click/fill/select via locator APIs
|
||||
|
||||
Bad pattern:
|
||||
|
||||
- stack arbitrary waits before every action
|
||||
- wait on unstable implementation details instead of the visible state the user cares about
|
||||
|
||||
### 5. Match debugging to the current suite
|
||||
|
||||
Playwright's wider ecosystem supports traces and rich debugging tools. Dify's current suite already captures:
|
||||
|
||||
- full-page screenshots
|
||||
- page HTML
|
||||
- console errors
|
||||
- page errors
|
||||
|
||||
Use the existing artifact flow by default. If a task is specifically about improving diagnostics, confirm the change fits the current Cucumber architecture before importing broader Playwright tooling.
|
||||
|
||||
## Review Questions
|
||||
|
||||
- Would this locator survive DOM refactors that do not change user-visible behavior?
|
||||
- Is this assertion using Playwright's retrying semantics?
|
||||
- Is any explicit wait masking a real readiness problem?
|
||||
- Does this code preserve per-scenario isolation?
|
||||
- Is a new abstraction really needed, or does it bypass the existing `DifyWorld` + step-definition model?
|
||||
1
.claude/skills/e2e-cucumber-playwright
Symbolic link
1
.claude/skills/e2e-cucumber-playwright
Symbolic link
@ -0,0 +1 @@
|
||||
../../.agents/skills/e2e-cucumber-playwright
|
||||
7
.github/workflows/docker-build.yml
vendored
7
.github/workflows/docker-build.yml
vendored
@ -6,14 +6,7 @@ on:
|
||||
- "main"
|
||||
paths:
|
||||
- api/Dockerfile
|
||||
- web/docker/**
|
||||
- web/Dockerfile
|
||||
- packages/**
|
||||
- package.json
|
||||
- pnpm-lock.yaml
|
||||
- pnpm-workspace.yaml
|
||||
- .npmrc
|
||||
- .nvmrc
|
||||
|
||||
concurrency:
|
||||
group: docker-build-${{ github.head_ref || github.run_id }}
|
||||
|
||||
1
.github/workflows/main-ci.yml
vendored
1
.github/workflows/main-ci.yml
vendored
@ -92,6 +92,7 @@ jobs:
|
||||
vdb:
|
||||
- 'api/core/rag/datasource/**'
|
||||
- 'api/tests/integration_tests/vdb/**'
|
||||
- 'api/providers/vdb/*/tests/**'
|
||||
- '.github/workflows/vdb-tests.yml'
|
||||
- '.github/workflows/expose_service_ports.sh'
|
||||
- 'docker/.env.example'
|
||||
|
||||
@ -62,7 +62,7 @@ jobs:
|
||||
- name: Render coverage markdown from structured data
|
||||
id: render
|
||||
run: |
|
||||
comment_body="$(uv run --directory api python api/libs/pyrefly_type_coverage.py \
|
||||
comment_body="$(uv run --directory api python libs/pyrefly_type_coverage.py \
|
||||
--base base_report.json \
|
||||
< pr_report.json)"
|
||||
|
||||
|
||||
2
.github/workflows/vdb-tests-full.yml
vendored
2
.github/workflows/vdb-tests-full.yml
vendored
@ -89,7 +89,7 @@ jobs:
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
# - name: Check VDB Ready (TiDB)
|
||||
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: uv run --project api bash dev/pytest/pytest_vdb.sh
|
||||
|
||||
10
.github/workflows/vdb-tests.yml
vendored
10
.github/workflows/vdb-tests.yml
vendored
@ -81,12 +81,12 @@ jobs:
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
# - name: Check VDB Ready (TiDB)
|
||||
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: |
|
||||
uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
|
||||
api/tests/integration_tests/vdb/chroma \
|
||||
api/tests/integration_tests/vdb/pgvector \
|
||||
api/tests/integration_tests/vdb/qdrant \
|
||||
api/tests/integration_tests/vdb/weaviate
|
||||
api/providers/vdb/vdb-chroma/tests/integration_tests \
|
||||
api/providers/vdb/vdb-pgvector/tests/integration_tests \
|
||||
api/providers/vdb/vdb-qdrant/tests/integration_tests \
|
||||
api/providers/vdb/vdb-weaviate/tests/integration_tests
|
||||
|
||||
@ -57,6 +57,9 @@ REDIS_SSL_CERTFILE=
|
||||
REDIS_SSL_KEYFILE=
|
||||
# Path to client private key file for SSL authentication
|
||||
REDIS_DB=0
|
||||
# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
|
||||
# Leave empty to preserve current unprefixed behavior.
|
||||
REDIS_KEY_PREFIX=
|
||||
|
||||
# redis Sentinel configuration.
|
||||
REDIS_USE_SENTINEL=false
|
||||
|
||||
@ -69,8 +69,6 @@ ignore = [
|
||||
"FURB152", # math-constant
|
||||
"UP007", # non-pep604-annotation
|
||||
"UP032", # f-string
|
||||
"UP045", # non-pep604-annotation-optional
|
||||
"B005", # strip-with-multi-characters
|
||||
"B006", # mutable-argument-default
|
||||
"B007", # unused-loop-control-variable
|
||||
"B026", # star-arg-unpacking-after-keyword-arg
|
||||
@ -84,7 +82,6 @@ ignore = [
|
||||
"SIM102", # collapsible-if
|
||||
"SIM103", # needless-bool
|
||||
"SIM105", # suppressible-exception
|
||||
"SIM107", # return-in-try-except-finally
|
||||
"SIM108", # if-else-block-instead-of-if-exp
|
||||
"SIM113", # enumerate-for-loop
|
||||
"SIM117", # multiple-with-statements
|
||||
@ -93,29 +90,16 @@ ignore = [
|
||||
]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
"__init__.py" = [
|
||||
"F401", # unused-import
|
||||
"F811", # redefined-while-unused
|
||||
]
|
||||
"configs/*" = [
|
||||
"N802", # invalid-function-name
|
||||
]
|
||||
"graphon/model_runtime/callbacks/base_callback.py" = ["T201"]
|
||||
"core/workflow/callbacks/workflow_logging_callback.py" = ["T201"]
|
||||
"libs/gmpy2_pkcs10aep_cipher.py" = [
|
||||
"N803", # invalid-argument-name
|
||||
]
|
||||
"tests/*" = [
|
||||
"F811", # redefined-while-unused
|
||||
"T201", # allow print in tests,
|
||||
"S110", # allow ignoring exceptions in tests code (currently)
|
||||
|
||||
]
|
||||
"controllers/console/explore/trial.py" = ["TID251"]
|
||||
"controllers/console/human_input_form.py" = ["TID251"]
|
||||
"controllers/web/human_input_form.py" = ["TID251"]
|
||||
|
||||
[lint.flake8-tidy-imports]
|
||||
|
||||
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
|
||||
msg = "Use Pydantic payload/query models instead of reqparse."
|
||||
|
||||
@ -21,8 +21,9 @@ RUN apt-get update \
|
||||
# for building gmpy2
|
||||
libmpfr-dev libmpc-dev
|
||||
|
||||
# Install Python dependencies
|
||||
# Install Python dependencies (workspace members under providers/vdb/)
|
||||
COPY pyproject.toml uv.lock ./
|
||||
COPY providers ./providers
|
||||
RUN uv sync --locked --no-dev
|
||||
|
||||
# production stage
|
||||
|
||||
@ -341,11 +341,10 @@ def add_qdrant_index(field: str):
|
||||
click.echo(click.style("No dataset collection bindings found.", fg="red"))
|
||||
return
|
||||
import qdrant_client
|
||||
from dify_vdb_qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
from qdrant_client.http.models import PayloadSchemaType
|
||||
|
||||
from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
|
||||
|
||||
for binding in bindings:
|
||||
if dify_config.QDRANT_URL is None:
|
||||
raise ValueError("Qdrant URL is required.")
|
||||
|
||||
@ -287,27 +287,6 @@ class MarketplaceConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class CreatorsPlatformConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for creators platform
|
||||
"""
|
||||
|
||||
CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
|
||||
description="Enable or disable creators platform features",
|
||||
default=True,
|
||||
)
|
||||
|
||||
CREATORS_PLATFORM_API_URL: HttpUrl = Field(
|
||||
description="Creators Platform API URL",
|
||||
default=HttpUrl("https://creators.dify.ai"),
|
||||
)
|
||||
|
||||
CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
|
||||
description="OAuth client_id for the Creators Platform app registered in Dify",
|
||||
default="",
|
||||
)
|
||||
|
||||
|
||||
class EndpointConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for various application endpoints and URLs
|
||||
@ -362,15 +341,6 @@ class FileAccessConfig(BaseSettings):
|
||||
default="",
|
||||
)
|
||||
|
||||
FILES_API_URL: str = Field(
|
||||
description="Base URL for storage file ticket API endpoints."
|
||||
" Used by sandbox containers (internal or external like e2b) that need"
|
||||
" an absolute, routable address to upload/download files via the API."
|
||||
" For all-in-one Docker deployments, set to http://localhost."
|
||||
" For public sandbox environments, set to a public domain or IP.",
|
||||
default="",
|
||||
)
|
||||
|
||||
FILES_ACCESS_TIMEOUT: int = Field(
|
||||
description="Expiration time in seconds for file access URLs",
|
||||
default=300,
|
||||
@ -1304,52 +1274,6 @@ class PositionConfig(BaseSettings):
|
||||
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
|
||||
|
||||
|
||||
class CollaborationConfig(BaseSettings):
|
||||
ENABLE_COLLABORATION_MODE: bool = Field(
|
||||
description="Whether to enable collaboration mode features across the workspace",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class SandboxExpiredRecordsCleanConfig(BaseSettings):
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field(
|
||||
description="Graceful period in days for sandbox records clean after subscription expiration",
|
||||
default=21,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field(
|
||||
description="Maximum number of records to process in each batch",
|
||||
default=1000,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL: PositiveInt = Field(
|
||||
description="Maximum interval in milliseconds between batches",
|
||||
default=200,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
|
||||
description="Retention days for sandbox expired workflow_run records and message records",
|
||||
default=30,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL: PositiveInt = Field(
|
||||
description="Lock TTL for sandbox expired records clean task in seconds",
|
||||
default=90000,
|
||||
)
|
||||
|
||||
|
||||
class AgentV2UpgradeConfig(BaseSettings):
|
||||
"""Feature flags for transparent Agent V2 upgrade."""
|
||||
|
||||
AGENT_V2_TRANSPARENT_UPGRADE: bool = Field(
|
||||
description="Transparently run old apps (chat/completion/agent-chat) through the Agent V2 workflow engine. "
|
||||
"When enabled, old apps synthesize a virtual workflow at runtime instead of using legacy runners.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
AGENT_V2_REPLACES_LLM: bool = Field(
|
||||
description="Transparently replace LLM nodes in workflows with Agent V2 nodes at runtime. "
|
||||
"LLMNodeData is remapped to AgentV2NodeData with tools=[] (identical behavior).",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class LoginConfig(BaseSettings):
|
||||
ENABLE_EMAIL_CODE_LOGIN: bool = Field(
|
||||
description="whether to enable email code login",
|
||||
@ -1419,6 +1343,29 @@ class TenantIsolatedTaskQueueConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class SandboxExpiredRecordsCleanConfig(BaseSettings):
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field(
|
||||
description="Graceful period in days for sandbox records clean after subscription expiration",
|
||||
default=21,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field(
|
||||
description="Maximum number of records to process in each batch",
|
||||
default=1000,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL: PositiveInt = Field(
|
||||
description="Maximum interval in milliseconds between batches",
|
||||
default=200,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
|
||||
description="Retention days for sandbox expired workflow_run records and message records",
|
||||
default=30,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL: PositiveInt = Field(
|
||||
description="Lock TTL for sandbox expired records clean task in seconds",
|
||||
default=90000,
|
||||
)
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
@ -1429,7 +1376,6 @@ class FeatureConfig(
|
||||
AsyncWorkflowConfig,
|
||||
PluginConfig,
|
||||
MarketplaceConfig,
|
||||
CreatorsPlatformConfig,
|
||||
DataSetConfig,
|
||||
EndpointConfig,
|
||||
FileAccessConfig,
|
||||
@ -1445,6 +1391,7 @@ class FeatureConfig(
|
||||
PositionConfig,
|
||||
RagEtlConfig,
|
||||
RepositoryConfig,
|
||||
SandboxExpiredRecordsCleanConfig,
|
||||
SecurityConfig,
|
||||
TenantIsolatedTaskQueueConfig,
|
||||
ToolConfig,
|
||||
@ -1452,9 +1399,6 @@ class FeatureConfig(
|
||||
WorkflowConfig,
|
||||
WorkflowNodeExecutionConfig,
|
||||
WorkspaceConfig,
|
||||
CollaborationConfig,
|
||||
AgentV2UpgradeConfig,
|
||||
SandboxExpiredRecordsCleanConfig,
|
||||
LoginConfig,
|
||||
AccountConfig,
|
||||
SwaggerUIConfig,
|
||||
|
||||
@ -160,6 +160,16 @@ class DatabaseConfig(BaseSettings):
|
||||
default="",
|
||||
)
|
||||
|
||||
DB_SESSION_TIMEZONE_OVERRIDE: str = Field(
|
||||
description=(
|
||||
"PostgreSQL session timezone override injected via startup options."
|
||||
" Default is 'UTC' for out-of-the-box consistency."
|
||||
" Set to empty string to disable app-level timezone injection, for example when using RDS Proxy"
|
||||
" together with a database-side default timezone."
|
||||
),
|
||||
default="UTC",
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str:
|
||||
@ -227,12 +237,13 @@ class DatabaseConfig(BaseSettings):
|
||||
connect_args: dict[str, str] = {}
|
||||
# Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property
|
||||
if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"):
|
||||
timezone_opt = "-c timezone=UTC"
|
||||
if options:
|
||||
merged_options = f"{options} {timezone_opt}"
|
||||
else:
|
||||
merged_options = timezone_opt
|
||||
connect_args = {"options": merged_options}
|
||||
merged_options = options.strip()
|
||||
session_timezone_override = self.DB_SESSION_TIMEZONE_OVERRIDE.strip()
|
||||
if session_timezone_override:
|
||||
timezone_opt = f"-c timezone={session_timezone_override}"
|
||||
merged_options = f"{merged_options} {timezone_opt}".strip() if merged_options else timezone_opt
|
||||
if merged_options:
|
||||
connect_args = {"options": merged_options}
|
||||
|
||||
result: SQLAlchemyEngineOptionsDict = {
|
||||
"pool_size": self.SQLALCHEMY_POOL_SIZE,
|
||||
|
||||
5
api/configs/middleware/cache/redis_config.py
vendored
5
api/configs/middleware/cache/redis_config.py
vendored
@ -32,6 +32,11 @@ class RedisConfig(BaseSettings):
|
||||
default=0,
|
||||
)
|
||||
|
||||
REDIS_KEY_PREFIX: str = Field(
|
||||
description="Optional global prefix for Redis keys, topics, and transport artifacts",
|
||||
default="",
|
||||
)
|
||||
|
||||
REDIS_USE_SSL: bool = Field(
|
||||
description="Enable SSL/TLS for the Redis connection",
|
||||
default=False,
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
@ -42,17 +41,17 @@ class HologresConfig(BaseSettings):
|
||||
default="public",
|
||||
)
|
||||
|
||||
HOLOGRES_TOKENIZER: TokenizerType = Field(
|
||||
HOLOGRES_TOKENIZER: str = Field(
|
||||
description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').",
|
||||
default="jieba",
|
||||
)
|
||||
|
||||
HOLOGRES_DISTANCE_METHOD: DistanceType = Field(
|
||||
HOLOGRES_DISTANCE_METHOD: str = Field(
|
||||
description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').",
|
||||
default="Cosine",
|
||||
)
|
||||
|
||||
HOLOGRES_BASE_QUANTIZATION_TYPE: BaseQuantizationType = Field(
|
||||
HOLOGRES_BASE_QUANTIZATION_TYPE: str = Field(
|
||||
description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').",
|
||||
default="rabitq",
|
||||
)
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
"""Configuration for InterSystems IRIS vector database."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, PositiveInt, model_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
@ -64,7 +66,7 @@ class IrisVectorConfig(BaseSettings):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Validate IRIS configuration values.
|
||||
|
||||
Args:
|
||||
|
||||
@ -81,20 +81,4 @@ default_app_templates: Mapping[AppMode, Mapping] = {
|
||||
},
|
||||
},
|
||||
},
|
||||
# agent default mode (new agent backed by single-node workflow)
|
||||
AppMode.AGENT: {
|
||||
"app": {
|
||||
"mode": AppMode.AGENT,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
},
|
||||
"model_config": {
|
||||
"model": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-4o",
|
||||
"mode": "chat",
|
||||
"completion_params": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateArgs, AdvancedPromptTemplateService
|
||||
|
||||
|
||||
class AdvancedPromptTemplateQuery(BaseModel):
|
||||
@ -35,5 +35,10 @@ class AdvancedPromptTemplateList(Resource):
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
return AdvancedPromptTemplateService.get_prompt(args.model_dump())
|
||||
prompt_args: AdvancedPromptTemplateArgs = {
|
||||
"app_mode": args.app_mode,
|
||||
"model_mode": args.model_mode,
|
||||
"model_name": args.model_name,
|
||||
"has_context": args.has_context,
|
||||
}
|
||||
return AdvancedPromptTemplateService.get_prompt(prompt_args)
|
||||
|
||||
@ -6,10 +6,9 @@ from typing import Any, Literal
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.file import helpers as file_helpers
|
||||
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.common.helpers import FileInfo
|
||||
@ -31,13 +30,14 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.trigger.constants import TRIGGER_NODE_TYPES
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import build_icon_url
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, DatasetPermissionEnum, Workflow
|
||||
from models.model import IconType
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.entities.dsl_entities import ImportMode
|
||||
from services.entities.dsl_entities import ImportMode, ImportStatus
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DataSource,
|
||||
InfoList,
|
||||
@ -52,7 +52,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
)
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion", "agent"]
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
|
||||
register_enum_models(console_ns, IconType)
|
||||
|
||||
@ -62,7 +62,7 @@ _logger = logging.getLogger(__name__)
|
||||
class AppListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"] = Field(
|
||||
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
|
||||
default="all", description="App mode filter"
|
||||
)
|
||||
name: str | None = Field(default=None, description="Filter by app name")
|
||||
@ -94,9 +94,7 @@ class AppListQuery(BaseModel):
|
||||
class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion", "agent"] = Field(
|
||||
..., description="App mode"
|
||||
)
|
||||
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
|
||||
icon_type: IconType | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
@ -163,15 +161,6 @@ def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
return value
|
||||
|
||||
|
||||
def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
|
||||
if icon is None or icon_type is None:
|
||||
return None
|
||||
icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
|
||||
if icon_type_value.lower() != IconType.IMAGE:
|
||||
return None
|
||||
return file_helpers.get_signed_file_url(icon)
|
||||
|
||||
|
||||
class Tag(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
@ -294,7 +283,7 @@ class Site(ResponseModel):
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return _build_icon_url(self.icon_type, self.icon)
|
||||
return build_icon_url(self.icon_type, self.icon)
|
||||
|
||||
@field_validator("icon_type", mode="before")
|
||||
@classmethod
|
||||
@ -344,7 +333,7 @@ class AppPartial(ResponseModel):
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return _build_icon_url(self.icon_type, self.icon)
|
||||
return build_icon_url(self.icon_type, self.icon)
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
@ -392,7 +381,7 @@ class AppDetailWithSite(AppDetail):
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return _build_icon_url(self.icon_type, self.icon)
|
||||
return build_icon_url(self.icon_type, self.icon)
|
||||
|
||||
|
||||
class AppPagination(ResponseModel):
|
||||
@ -634,7 +623,7 @@ class AppCopyApi(Resource):
|
||||
|
||||
args = CopyAppPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = AppDslService(session)
|
||||
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
|
||||
result = import_service.import_app(
|
||||
@ -647,6 +636,13 @@ class AppCopyApi(Resource):
|
||||
icon=args.icon,
|
||||
icon_background=args.icon_background,
|
||||
)
|
||||
if result.status == ImportStatus.FAILED:
|
||||
session.rollback()
|
||||
return result.model_dump(mode="json"), 400
|
||||
if result.status == ImportStatus.PENDING:
|
||||
session.rollback()
|
||||
return result.model_dump(mode="json"), 202
|
||||
session.commit()
|
||||
|
||||
# Inherit web app permission from original app
|
||||
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
@ -52,8 +52,9 @@ class AppImportApi(Resource):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = AppImportPayload.model_validate(console_ns.payload)
|
||||
|
||||
# Create service with session
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
# AppDslService performs internal commits for some creation paths, so use a plain
|
||||
# Session here instead of nesting it inside sessionmaker(...).begin().
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Import app
|
||||
account = current_user
|
||||
@ -69,6 +70,10 @@ class AppImportApi(Resource):
|
||||
icon_background=args.icon_background,
|
||||
app_id=args.app_id,
|
||||
)
|
||||
if result.status == ImportStatus.FAILED:
|
||||
session.rollback()
|
||||
else:
|
||||
session.commit()
|
||||
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
||||
# update web app setting as private
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
|
||||
@ -95,12 +100,15 @@ class AppImportConfirmApi(Resource):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# Create service with session
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Confirm import
|
||||
account = current_user
|
||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||
if result.status == ImportStatus.FAILED:
|
||||
session.rollback()
|
||||
else:
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED:
|
||||
@ -117,7 +125,7 @@ class AppImportCheckDependenciesApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = AppDslService(session)
|
||||
result = import_service.check_dependencies(app_model=app_model)
|
||||
|
||||
|
||||
@ -161,7 +161,7 @@ class ChatMessageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
args_model = ChatMessagePayload.model_validate(console_ns.payload)
|
||||
@ -215,7 +215,7 @@ class ChatMessageStopApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model, task_id):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
@ -1,44 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_variable_fields import (
|
||||
conversation_variable_fields,
|
||||
paginated_conversation_variable_fields,
|
||||
)
|
||||
from fields._value_type_serializer import serialize_value_type
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import login_required
|
||||
from models import ConversationVariable
|
||||
from models.model import AppMode
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ConversationVariablesQuery(BaseModel):
|
||||
conversation_id: str = Field(..., description="Conversation ID to filter variables")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ConversationVariablesQuery.__name__,
|
||||
ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register base model first
|
||||
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
|
||||
|
||||
# For nested models, need to replace nested dict with registered model
|
||||
paginated_conversation_variable_fields_copy = paginated_conversation_variable_fields.copy()
|
||||
paginated_conversation_variable_fields_copy["data"] = fields.List(
|
||||
fields.Nested(conversation_variable_model), attribute="data"
|
||||
)
|
||||
paginated_conversation_variable_model = console_ns.model(
|
||||
"PaginatedConversationVariable", paginated_conversation_variable_fields_copy
|
||||
class ConversationVariableResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
value_type: str
|
||||
value: str | None = None
|
||||
description: str | None = None
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
|
||||
@field_validator("value_type", mode="before")
|
||||
@classmethod
|
||||
def _normalize_value_type(cls, value: Any) -> str:
|
||||
exposed_type = getattr(value, "exposed_type", None)
|
||||
if callable(exposed_type):
|
||||
return str(exposed_type().value)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
return serialize_value_type(value)
|
||||
except Exception:
|
||||
return serialize_value_type({"value_type": value})
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def _normalize_value(cls, value: Any | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class PaginatedConversationVariableResponse(ResponseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[ConversationVariableResponse]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
ConversationVariablesQuery,
|
||||
ConversationVariableResponse,
|
||||
PaginatedConversationVariableResponse,
|
||||
)
|
||||
|
||||
|
||||
@ -48,12 +90,15 @@ class ConversationVariablesApi(Resource):
|
||||
@console_ns.doc(description="Get conversation variables for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
|
||||
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Conversation variables retrieved successfully",
|
||||
console_ns.models[PaginatedConversationVariableResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
@marshal_with(paginated_conversation_variable_model)
|
||||
def get(self, app_model):
|
||||
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
@ -72,17 +117,22 @@ class ConversationVariablesApi(Resource):
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
rows = session.scalars(stmt).all()
|
||||
|
||||
return {
|
||||
"page": page,
|
||||
"limit": page_size,
|
||||
"total": len(rows),
|
||||
"has_more": False,
|
||||
"data": [
|
||||
{
|
||||
"created_at": row.created_at,
|
||||
"updated_at": row.updated_at,
|
||||
**row.to_variable().model_dump(),
|
||||
}
|
||||
for row in rows
|
||||
],
|
||||
}
|
||||
response = PaginatedConversationVariableResponse.model_validate(
|
||||
{
|
||||
"page": page,
|
||||
"limit": page_size,
|
||||
"total": len(rows),
|
||||
"has_more": False,
|
||||
"data": [
|
||||
ConversationVariableResponse.model_validate(
|
||||
{
|
||||
"created_at": row.created_at,
|
||||
"updated_at": row.updated_at,
|
||||
**row.to_variable().model_dump(),
|
||||
}
|
||||
)
|
||||
for row in rows
|
||||
],
|
||||
}
|
||||
)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
@ -26,13 +26,13 @@ def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
|
||||
class MCPServerCreatePayload(BaseModel):
|
||||
description: str | None = Field(default=None, description="Server description")
|
||||
parameters: dict = Field(..., description="Server parameters configuration")
|
||||
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
|
||||
|
||||
|
||||
class MCPServerUpdatePayload(BaseModel):
|
||||
id: str = Field(..., description="Server ID")
|
||||
description: str | None = Field(default=None, description="Server description")
|
||||
parameters: dict = Field(..., description="Server parameters configuration")
|
||||
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
|
||||
status: str | None = Field(default=None, description="Server status")
|
||||
|
||||
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx import Resource
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import exists, func, select
|
||||
@ -25,10 +26,21 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from extensions.ext_database import db
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from fields.base import ResponseModel
|
||||
from fields.conversation_fields import (
|
||||
AgentThought,
|
||||
ConversationAnnotation,
|
||||
ConversationAnnotationHitHistory,
|
||||
Feedback,
|
||||
JSONValue,
|
||||
MessageFile,
|
||||
format_files_contained,
|
||||
to_timestamp,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
@ -98,6 +110,51 @@ class SuggestedQuestionsResponse(BaseModel):
|
||||
data: list[str] = Field(description="Suggested question")
|
||||
|
||||
|
||||
class MessageDetailResponse(ResponseModel):
|
||||
id: str
|
||||
conversation_id: str
|
||||
inputs: dict[str, JSONValue]
|
||||
query: str
|
||||
message: JSONValue | None = None
|
||||
message_tokens: int | None = None
|
||||
answer: str = Field(validation_alias="re_sign_file_url_answer")
|
||||
answer_tokens: int | None = None
|
||||
provider_response_latency: float | None = None
|
||||
from_source: str
|
||||
from_end_user_id: str | None = None
|
||||
from_account_id: str | None = None
|
||||
feedbacks: list[Feedback] = Field(default_factory=list)
|
||||
workflow_run_id: str | None = None
|
||||
annotation: ConversationAnnotation | None = None
|
||||
annotation_hit_history: ConversationAnnotationHitHistory | None = None
|
||||
created_at: int | None = None
|
||||
agent_thoughts: list[AgentThought] = Field(default_factory=list)
|
||||
message_files: list[MessageFile] = Field(default_factory=list)
|
||||
extra_contents: list[ExecutionExtraContentDomainModel] = Field(default_factory=list)
|
||||
metadata: JSONValue | None = Field(default=None, validation_alias="message_metadata_dict")
|
||||
status: str
|
||||
error: str | None = None
|
||||
parent_message_id: str | None = None
|
||||
|
||||
@field_validator("inputs", mode="before")
|
||||
@classmethod
|
||||
def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
|
||||
return format_files_contained(value)
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
|
||||
class MessageInfiniteScrollPaginationResponse(ResponseModel):
|
||||
limit: int
|
||||
has_more: bool
|
||||
data: list[MessageDetailResponse]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
ChatMessagesQuery,
|
||||
@ -105,124 +162,8 @@ register_schema_models(
|
||||
FeedbackExportQuery,
|
||||
AnnotationCountResponse,
|
||||
SuggestedQuestionsResponse,
|
||||
)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
||||
# Base models
|
||||
simple_account_model = console_ns.model(
|
||||
"SimpleAccount",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"email": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
message_file_model = console_ns.model(
|
||||
"MessageFile",
|
||||
{
|
||||
"id": fields.String,
|
||||
"filename": fields.String,
|
||||
"type": fields.String,
|
||||
"url": fields.String,
|
||||
"mime_type": fields.String,
|
||||
"size": fields.Integer,
|
||||
"transfer_method": fields.String,
|
||||
"belongs_to": fields.String(default="user"),
|
||||
"upload_file_id": fields.String(default=None),
|
||||
},
|
||||
)
|
||||
|
||||
agent_thought_model = console_ns.model(
|
||||
"AgentThought",
|
||||
{
|
||||
"id": fields.String,
|
||||
"chain_id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"thought": fields.String,
|
||||
"tool": fields.String,
|
||||
"tool_labels": fields.Raw,
|
||||
"tool_input": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"observation": fields.String,
|
||||
"files": fields.List(fields.String),
|
||||
},
|
||||
)
|
||||
|
||||
# Models that depend on simple_account_model
|
||||
feedback_model = console_ns.model(
|
||||
"Feedback",
|
||||
{
|
||||
"rating": fields.String,
|
||||
"content": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account": fields.Nested(simple_account_model, allow_null=True),
|
||||
},
|
||||
)
|
||||
|
||||
annotation_model = console_ns.model(
|
||||
"Annotation",
|
||||
{
|
||||
"id": fields.String,
|
||||
"question": fields.String,
|
||||
"content": fields.String,
|
||||
"account": fields.Nested(simple_account_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
},
|
||||
)
|
||||
|
||||
annotation_hit_history_model = console_ns.model(
|
||||
"AnnotationHitHistory",
|
||||
{
|
||||
"annotation_id": fields.String(attribute="id"),
|
||||
"annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
},
|
||||
)
|
||||
|
||||
# Message detail model that depends on multiple models
|
||||
message_detail_model = console_ns.model(
|
||||
"MessageDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"message": fields.Raw,
|
||||
"message_tokens": fields.Integer,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"answer_tokens": fields.Integer,
|
||||
"provider_response_latency": fields.Float,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"feedbacks": fields.List(fields.Nested(feedback_model)),
|
||||
"workflow_run_id": fields.String,
|
||||
"annotation": fields.Nested(annotation_model, allow_null=True),
|
||||
"annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
|
||||
"message_files": fields.List(fields.Nested(message_file_model)),
|
||||
"extra_contents": fields.List(fields.Raw),
|
||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
# Message infinite scroll pagination model
|
||||
message_infinite_scroll_pagination_model = console_ns.model(
|
||||
"MessageInfiniteScrollPagination",
|
||||
{
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_detail_model)),
|
||||
},
|
||||
MessageDetailResponse,
|
||||
MessageInfiniteScrollPaginationResponse,
|
||||
)
|
||||
|
||||
|
||||
@ -232,13 +173,12 @@ class ChatMessageListApi(Resource):
|
||||
@console_ns.doc(description="Get chat messages for a conversation with pagination")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
|
||||
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[MessageInfiniteScrollPaginationResponse.__name__])
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@marshal_with(message_infinite_scroll_pagination_model)
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict())
|
||||
@ -298,7 +238,10 @@ class ChatMessageListApi(Resource):
|
||||
history_messages = list(reversed(history_messages))
|
||||
attach_message_extra_contents(history_messages)
|
||||
|
||||
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
|
||||
return MessageInfiniteScrollPaginationResponse.model_validate(
|
||||
InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more),
|
||||
from_attributes=True,
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
|
||||
@ -393,7 +336,7 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model, message_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
message_id = str(message_id)
|
||||
@ -468,13 +411,12 @@ class MessageApi(Resource):
|
||||
@console_ns.doc("get_message")
|
||||
@console_ns.doc(description="Get message details by ID")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
|
||||
@console_ns.response(200, "Message retrieved successfully", message_detail_model)
|
||||
@console_ns.response(200, "Message retrieved successfully", console_ns.models[MessageDetailResponse.__name__])
|
||||
@console_ns.response(404, "Message not found")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(message_detail_model)
|
||||
def get(self, app_model, message_id: str):
|
||||
message_id = str(message_id)
|
||||
|
||||
@ -486,4 +428,4 @@ class MessageApi(Resource):
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
attach_message_extra_contents([message])
|
||||
return message
|
||||
return MessageDetailResponse.model_validate(message, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@ -4,7 +4,7 @@ from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from graphon.enums import NodeType
|
||||
from graphon.file import File
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
@ -206,7 +206,7 @@ class DraftWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
@ -226,7 +226,7 @@ class DraftWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@console_ns.doc("sync_draft_workflow")
|
||||
@console_ns.doc(description="Sync draft workflow configuration")
|
||||
@console_ns.expect(console_ns.models[SyncDraftWorkflowPayload.__name__])
|
||||
@ -310,7 +310,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
@ -356,7 +356,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
@ -432,7 +432,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
@ -534,7 +534,7 @@ class AdvancedChatDraftHumanInputFormPreviewApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
@ -563,7 +563,7 @@ class AdvancedChatDraftHumanInputFormRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
@ -718,7 +718,7 @@ class WorkflowTaskStopApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, task_id: str):
|
||||
"""
|
||||
@ -746,7 +746,7 @@ class DraftWorkflowNodeRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_node_execution_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
@ -792,7 +792,7 @@ class PublishedWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
@ -810,7 +810,7 @@ class PublishedWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
@ -854,7 +854,7 @@ class DefaultBlockConfigsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@ -876,7 +876,7 @@ class DefaultBlockConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App, block_type: str):
|
||||
"""
|
||||
@ -941,8 +941,7 @@ class PublishedAllWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@marshal_with(workflow_pagination_model)
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@ -970,9 +969,10 @@ class PublishedAllWorkflowApi(Resource):
|
||||
user_id=user_id,
|
||||
named_only=named_only,
|
||||
)
|
||||
serialized_workflows = marshal(workflows, workflow_fields_copy)
|
||||
|
||||
return {
|
||||
"items": workflows,
|
||||
"items": serialized_workflows,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"has_more": has_more,
|
||||
@ -990,7 +990,7 @@ class DraftWorkflowRestoreApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, workflow_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@ -1028,7 +1028,7 @@ class WorkflowByIdApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_model)
|
||||
@edit_permission_required
|
||||
def patch(self, app_model: App, workflow_id: str):
|
||||
@ -1068,7 +1068,7 @@ class WorkflowByIdApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, workflow_id: str):
|
||||
"""
|
||||
@ -1103,7 +1103,7 @@ class DraftWorkflowNodeLastRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_node_execution_model)
|
||||
def get(self, app_model: App, node_id: str):
|
||||
srv = WorkflowService()
|
||||
|
||||
@ -1,27 +1,26 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with
|
||||
from flask_restx import Resource
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import (
|
||||
build_workflow_app_log_pagination_model,
|
||||
build_workflow_archived_log_pagination_model,
|
||||
)
|
||||
from fields.base import ResponseModel
|
||||
from fields.end_user_fields import SimpleEndUser
|
||||
from fields.member_fields import SimpleAccount
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowAppLogQuery(BaseModel):
|
||||
keyword: str | None = Field(default=None, description="Search keyword for filtering logs")
|
||||
@ -58,13 +57,113 @@ class WorkflowAppLogQuery(BaseModel):
|
||||
raise ValueError("Invalid boolean value for detail")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
class WorkflowRunForLogResponse(ResponseModel):
|
||||
id: str
|
||||
version: str | None = None
|
||||
status: str | None = None
|
||||
triggered_from: str | None = None
|
||||
error: str | None = None
|
||||
elapsed_time: float | None = None
|
||||
total_tokens: int | None = None
|
||||
total_steps: int | None = None
|
||||
created_at: int | None = None
|
||||
finished_at: int | None = None
|
||||
exceptions_count: int | None = None
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
|
||||
workflow_archived_log_pagination_model = build_workflow_archived_log_pagination_model(console_ns)
|
||||
@field_validator("status", mode="before")
|
||||
@classmethod
|
||||
def _normalize_status(cls, value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(getattr(value, "value", value))
|
||||
|
||||
@field_validator("created_at", "finished_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class WorkflowRunForArchivedLogResponse(ResponseModel):
|
||||
id: str
|
||||
status: str | None = None
|
||||
triggered_from: str | None = None
|
||||
elapsed_time: float | None = None
|
||||
total_tokens: int | None = None
|
||||
|
||||
@field_validator("status", mode="before")
|
||||
@classmethod
|
||||
def _normalize_status(cls, value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(getattr(value, "value", value))
|
||||
|
||||
|
||||
class WorkflowAppLogPartialResponse(ResponseModel):
|
||||
id: str
|
||||
workflow_run: WorkflowRunForLogResponse | None = None
|
||||
details: Any = None
|
||||
created_from: str | None = None
|
||||
created_by_role: str | None = None
|
||||
created_by_account: SimpleAccount | None = None
|
||||
created_by_end_user: SimpleEndUser | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class WorkflowArchivedLogPartialResponse(ResponseModel):
|
||||
id: str
|
||||
workflow_run: WorkflowRunForArchivedLogResponse | None = None
|
||||
trigger_metadata: Any = None
|
||||
created_by_account: SimpleAccount | None = None
|
||||
created_by_end_user: SimpleEndUser | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class WorkflowAppLogPaginationResponse(ResponseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[WorkflowAppLogPartialResponse]
|
||||
|
||||
|
||||
class WorkflowArchivedLogPaginationResponse(ResponseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[WorkflowArchivedLogPartialResponse]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
WorkflowAppLogQuery,
|
||||
WorkflowRunForLogResponse,
|
||||
WorkflowRunForArchivedLogResponse,
|
||||
WorkflowAppLogPartialResponse,
|
||||
WorkflowArchivedLogPartialResponse,
|
||||
WorkflowAppLogPaginationResponse,
|
||||
WorkflowArchivedLogPaginationResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-app-logs")
|
||||
@ -73,12 +172,15 @@ class WorkflowAppLogApi(Resource):
|
||||
@console_ns.doc(description="Get workflow application execution logs")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
|
||||
@console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow app logs retrieved successfully",
|
||||
console_ns.models[WorkflowAppLogPaginationResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_app_log_pagination_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow app logs
|
||||
@ -87,7 +189,7 @@ class WorkflowAppLogApi(Resource):
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
@ -102,7 +204,9 @@ class WorkflowAppLogApi(Resource):
|
||||
created_by_account=args.created_by_account,
|
||||
)
|
||||
|
||||
return workflow_app_log_pagination
|
||||
return WorkflowAppLogPaginationResponse.model_validate(
|
||||
workflow_app_log_pagination, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-archived-logs")
|
||||
@ -111,12 +215,15 @@ class WorkflowArchivedLogApi(Resource):
|
||||
@console_ns.doc(description="Get workflow archived execution logs")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
|
||||
@console_ns.response(200, "Workflow archived logs retrieved successfully", workflow_archived_log_pagination_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow archived logs retrieved successfully",
|
||||
console_ns.models[WorkflowArchivedLogPaginationResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_archived_log_pagination_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow archived logs
|
||||
@ -124,7 +231,7 @@ class WorkflowArchivedLogApi(Resource):
|
||||
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
workflow_app_service = WorkflowAppService()
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
@ -132,4 +239,6 @@ class WorkflowArchivedLogApi(Resource):
|
||||
limit=args.limit,
|
||||
)
|
||||
|
||||
return workflow_app_log_pagination
|
||||
return WorkflowArchivedLogPaginationResponse.model_validate(
|
||||
workflow_app_log_pagination, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
@ -1,322 +0,0 @@
|
||||
import logging
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.member_fields import AccountWithRole
|
||||
from fields.workflow_comment_fields import (
|
||||
workflow_comment_basic_fields,
|
||||
workflow_comment_create_fields,
|
||||
workflow_comment_detail_fields,
|
||||
workflow_comment_reply_create_fields,
|
||||
workflow_comment_reply_update_fields,
|
||||
workflow_comment_resolve_fields,
|
||||
workflow_comment_update_fields,
|
||||
)
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.workflow_comment_service import WorkflowCommentService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowCommentCreatePayload(BaseModel):
|
||||
position_x: float = Field(..., description="Comment X position")
|
||||
position_y: float = Field(..., description="Comment Y position")
|
||||
content: str = Field(..., description="Comment content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentUpdatePayload(BaseModel):
|
||||
content: str = Field(..., description="Comment content")
|
||||
position_x: float | None = Field(default=None, description="Comment X position")
|
||||
position_y: float | None = Field(default=None, description="Comment Y position")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentReplyCreatePayload(BaseModel):
|
||||
content: str = Field(..., description="Reply content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentReplyUpdatePayload(BaseModel):
|
||||
content: str = Field(..., description="Reply content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentMentionUsersResponse(BaseModel):
|
||||
users: list[AccountWithRole] = Field(description="Mentionable users")
|
||||
|
||||
|
||||
for model in (
|
||||
WorkflowCommentCreatePayload,
|
||||
WorkflowCommentUpdatePayload,
|
||||
WorkflowCommentReplyCreatePayload,
|
||||
WorkflowCommentReplyUpdatePayload,
|
||||
):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
for model in (AccountWithRole, WorkflowCommentMentionUsersResponse):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields)
|
||||
workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields)
|
||||
workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields)
|
||||
workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields)
|
||||
workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields)
|
||||
workflow_comment_reply_create_model = console_ns.model(
|
||||
"WorkflowCommentReplyCreate", workflow_comment_reply_create_fields
|
||||
)
|
||||
workflow_comment_reply_update_model = console_ns.model(
|
||||
"WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields
|
||||
)
|
||||
workflow_comment_mention_users_model = console_ns.models[WorkflowCommentMentionUsersResponse.__name__]
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments")
|
||||
class WorkflowCommentListApi(Resource):
|
||||
"""API for listing and creating workflow comments."""
|
||||
|
||||
@console_ns.doc("list_workflow_comments")
|
||||
@console_ns.doc(description="Get all comments for a workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_basic_model, envelope="data")
|
||||
def get(self, app_model: App):
|
||||
"""Get all comments for a workflow."""
|
||||
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
|
||||
|
||||
return comments
|
||||
|
||||
@console_ns.doc("create_workflow_comment")
|
||||
@console_ns.doc(description="Create a new workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__])
|
||||
@console_ns.response(201, "Comment created successfully", workflow_comment_create_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_create_model)
|
||||
def post(self, app_model: App):
|
||||
"""Create a new workflow comment."""
|
||||
payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
content=payload.content,
|
||||
position_x=payload.position_x,
|
||||
position_y=payload.position_y,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
|
||||
class WorkflowCommentDetailApi(Resource):
|
||||
"""API for managing individual workflow comments."""
|
||||
|
||||
@console_ns.doc("get_workflow_comment")
|
||||
@console_ns.doc(description="Get a specific workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_detail_model)
|
||||
def get(self, app_model: App, comment_id: str):
|
||||
"""Get a specific workflow comment."""
|
||||
comment = WorkflowCommentService.get_comment(
|
||||
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
@console_ns.doc("update_workflow_comment")
|
||||
@console_ns.doc(description="Update a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Comment updated successfully", workflow_comment_update_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_update_model)
|
||||
def put(self, app_model: App, comment_id: str):
|
||||
"""Update a workflow comment."""
|
||||
payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.update_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
content=payload.content,
|
||||
position_x=payload.position_x,
|
||||
position_y=payload.position_y,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@console_ns.doc("delete_workflow_comment")
|
||||
@console_ns.doc(description="Delete a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(204, "Comment deleted successfully")
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def delete(self, app_model: App, comment_id: str):
|
||||
"""Delete a workflow comment."""
|
||||
WorkflowCommentService.delete_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
|
||||
class WorkflowCommentResolveApi(Resource):
|
||||
"""API for resolving and reopening workflow comments."""
|
||||
|
||||
@console_ns.doc("resolve_workflow_comment")
|
||||
@console_ns.doc(description="Resolve a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_resolve_model)
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Resolve a workflow comment."""
|
||||
comment = WorkflowCommentService.resolve_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
|
||||
class WorkflowCommentReplyApi(Resource):
|
||||
"""API for managing comment replies."""
|
||||
|
||||
@console_ns.doc("create_workflow_comment_reply")
|
||||
@console_ns.doc(description="Add a reply to a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyCreatePayload.__name__])
|
||||
@console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_create_model)
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Add a reply to a workflow comment."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_reply(
|
||||
comment_id=comment_id,
|
||||
content=payload.content,
|
||||
created_by=current_user.id,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>")
|
||||
class WorkflowCommentReplyDetailApi(Resource):
|
||||
"""API for managing individual comment replies."""
|
||||
|
||||
@console_ns.doc("update_workflow_comment_reply")
|
||||
@console_ns.doc(description="Update a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_update_model)
|
||||
def put(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Update a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
reply = WorkflowCommentService.update_reply(
|
||||
reply_id=reply_id,
|
||||
user_id=current_user.id,
|
||||
content=payload.content,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return reply
|
||||
|
||||
@console_ns.doc("delete_workflow_comment_reply")
|
||||
@console_ns.doc(description="Delete a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.response(204, "Reply deleted successfully")
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def delete(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Delete a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
WorkflowCommentService.delete_reply(reply_id=reply_id, user_id=current_user.id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")
|
||||
class WorkflowCommentMentionUsersApi(Resource):
|
||||
"""API for getting mentionable users for workflow comments."""
|
||||
|
||||
@console_ns.doc("workflow_comment_mention_users")
|
||||
@console_ns.doc(description="Get all users in current tenant for mentions")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Mentionable users retrieved successfully", workflow_comment_mention_users_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App):
|
||||
"""Get all users in current tenant for mentions."""
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = WorkflowCommentMentionUsersResponse(users=member_models)
|
||||
return response.model_dump(mode="json"), 200
|
||||
@ -216,7 +216,7 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@wraps(f)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
@ -36,7 +36,7 @@ from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowR
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
from services.workflow_run_service import WorkflowRunListArgs, WorkflowRunService
|
||||
|
||||
|
||||
def _build_backstage_input_url(form_token: str | None) -> str | None:
|
||||
@ -207,14 +207,18 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(advanced_chat_workflow_run_pagination_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get advanced chat app workflow run list
|
||||
"""
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
args: WorkflowRunListArgs = {"limit": args_model.limit}
|
||||
if args_model.last_id is not None:
|
||||
args["last_id"] = args_model.last_id
|
||||
if args_model.status is not None:
|
||||
args["status"] = args_model.status
|
||||
|
||||
# Default to DEBUGGING if not specified
|
||||
triggered_from = (
|
||||
@ -305,7 +309,7 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(workflow_run_count_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@ -349,14 +353,18 @@ class WorkflowRunListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_pagination_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow run list
|
||||
"""
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
args: WorkflowRunListArgs = {"limit": args_model.limit}
|
||||
if args_model.last_id is not None:
|
||||
args["last_id"] = args_model.last_id
|
||||
if args_model.status is not None:
|
||||
args["status"] = args_model.status
|
||||
|
||||
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||
triggered_from = (
|
||||
@ -397,7 +405,7 @@ class WorkflowRunCountApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_count_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@ -434,7 +442,7 @@ class WorkflowRunDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_detail_model)
|
||||
def get(self, app_model: App, run_id):
|
||||
"""
|
||||
@ -458,7 +466,7 @@ class WorkflowRunNodeExecutionListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_node_execution_list_model)
|
||||
def get(self, app_model: App, run_id):
|
||||
"""
|
||||
|
||||
@ -1,16 +1,17 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import get_or_create_model
|
||||
from controllers.common.schema import register_schema_models
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_user, login_required
|
||||
from models.enums import AppTriggerStatus
|
||||
from models.model import Account, App, AppMode
|
||||
@ -21,15 +22,6 @@ from ..app.wraps import get_app_model
|
||||
from ..wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
trigger_model = get_or_create_model("WorkflowTrigger", trigger_fields)
|
||||
|
||||
triggers_list_fields_copy = triggers_list_fields.copy()
|
||||
triggers_list_fields_copy["data"] = fields.List(fields.Nested(trigger_model))
|
||||
triggers_list_model = get_or_create_model("WorkflowTriggerList", triggers_list_fields_copy)
|
||||
|
||||
webhook_trigger_model = get_or_create_model("WebhookTrigger", webhook_trigger_fields)
|
||||
|
||||
|
||||
class Parser(BaseModel):
|
||||
@ -41,10 +33,52 @@ class ParserEnable(BaseModel):
|
||||
enable_trigger: bool
|
||||
|
||||
|
||||
console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
class WorkflowTriggerResponse(ResponseModel):
|
||||
id: str
|
||||
trigger_type: str
|
||||
title: str
|
||||
node_id: str
|
||||
provider_name: str
|
||||
icon: str
|
||||
status: str
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
@field_validator("id", "trigger_type", "title", "node_id", "provider_name", "icon", "status", mode="before")
|
||||
@classmethod
|
||||
def _normalize_string_fields(cls, value: object) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
|
||||
class WorkflowTriggerListResponse(ResponseModel):
|
||||
data: list[WorkflowTriggerResponse]
|
||||
|
||||
|
||||
class WebhookTriggerResponse(ResponseModel):
|
||||
id: str
|
||||
webhook_id: str
|
||||
webhook_url: str
|
||||
webhook_debug_url: str
|
||||
node_id: str
|
||||
created_at: datetime | None = None
|
||||
|
||||
@field_validator("id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", mode="before")
|
||||
@classmethod
|
||||
def _normalize_string_fields(cls, value: object) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
Parser,
|
||||
ParserEnable,
|
||||
WorkflowTriggerResponse,
|
||||
WorkflowTriggerListResponse,
|
||||
WebhookTriggerResponse,
|
||||
)
|
||||
|
||||
|
||||
@ -57,14 +91,14 @@ class WebhookTriggerApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(webhook_trigger_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[WebhookTriggerResponse.__name__])
|
||||
def get(self, app_model: App):
|
||||
"""Get webhook trigger for a node"""
|
||||
args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
node_id = args.node_id
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
# Get webhook trigger for this app and node
|
||||
webhook_trigger = session.scalar(
|
||||
select(WorkflowWebhookTrigger)
|
||||
@ -78,7 +112,7 @@ class WebhookTriggerApi(Resource):
|
||||
if not webhook_trigger:
|
||||
raise NotFound("Webhook trigger not found for this node")
|
||||
|
||||
return webhook_trigger
|
||||
return WebhookTriggerResponse.model_validate(webhook_trigger, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/triggers")
|
||||
@ -89,13 +123,13 @@ class AppTriggersApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(triggers_list_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[WorkflowTriggerListResponse.__name__])
|
||||
def get(self, app_model: App):
|
||||
"""Get app triggers list"""
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
# Get all triggers for this app using select API
|
||||
triggers = (
|
||||
session.execute(
|
||||
@ -118,7 +152,9 @@ class AppTriggersApi(Resource):
|
||||
else:
|
||||
trigger.icon = "" # type: ignore
|
||||
|
||||
return {"data": triggers}
|
||||
return WorkflowTriggerListResponse.model_validate({"data": triggers}, from_attributes=True).model_dump(
|
||||
mode="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
|
||||
@ -129,7 +165,7 @@ class AppTriggerEnableApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(trigger_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[WorkflowTriggerResponse.__name__])
|
||||
def post(self, app_model: App):
|
||||
"""Update app trigger (enable/disable)"""
|
||||
args = ParserEnable.model_validate(console_ns.payload)
|
||||
@ -160,4 +196,4 @@ class AppTriggerEnableApi(Resource):
|
||||
else:
|
||||
trigger.icon = "" # type: ignore
|
||||
|
||||
return trigger
|
||||
return WorkflowTriggerResponse.model_validate(trigger, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
@ -40,7 +42,7 @@ class ActivatePayload(BaseModel):
|
||||
|
||||
class ActivationCheckResponse(BaseModel):
|
||||
is_valid: bool = Field(description="Whether token is valid")
|
||||
data: dict | None = Field(default=None, description="Activation data if valid")
|
||||
data: dict[str, Any] | None = Field(default=None, description="Activation data if valid")
|
||||
|
||||
|
||||
class ActivationResponse(BaseModel):
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
import logging
|
||||
|
||||
import flask_login
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
@ -42,12 +45,13 @@ from libs.token import (
|
||||
)
|
||||
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
from services.entities.auth_entities import LoginPayloadBase
|
||||
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
|
||||
from services.errors.account import AccountRegisterError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoginPayload(LoginPayloadBase):
|
||||
@ -91,10 +95,12 @@ class LoginApi(Resource):
|
||||
normalized_email = request_email.lower()
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
|
||||
raise AccountInFreezeError()
|
||||
|
||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
|
||||
if is_login_error_rate_limit:
|
||||
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.LOGIN_RATE_LIMITED)
|
||||
raise EmailPasswordLoginLimitError()
|
||||
|
||||
invite_token = args.invite_token
|
||||
@ -110,14 +116,20 @@ class LoginApi(Resource):
|
||||
invitee_email = data.get("email") if data else None
|
||||
invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
|
||||
if invitee_email_normalized != normalized_email:
|
||||
_log_console_login_failure(
|
||||
email=normalized_email,
|
||||
reason=LoginFailureReason.INVALID_INVITATION_EMAIL,
|
||||
)
|
||||
raise InvalidEmailError()
|
||||
account = _authenticate_account_with_case_fallback(
|
||||
request_email, normalized_email, args.password, invite_token
|
||||
)
|
||||
except services.errors.account.AccountLoginError:
|
||||
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError as exc:
|
||||
AccountService.add_login_error_rate_limit(normalized_email)
|
||||
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
|
||||
raise AuthenticationFailedError() from exc
|
||||
# SELF_HOSTED only have one workspace
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
@ -240,20 +252,27 @@ class EmailCodeLoginApi(Resource):
|
||||
|
||||
token_data = AccountService.get_email_code_login_data(args.token)
|
||||
if token_data is None:
|
||||
_log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
|
||||
raise InvalidTokenError()
|
||||
|
||||
token_email = token_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
if normalized_token_email != user_email:
|
||||
_log_console_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != args.code:
|
||||
_log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
|
||||
raise EmailCodeError()
|
||||
|
||||
AccountService.revoke_email_code_login_token(args.token)
|
||||
try:
|
||||
account = _get_account_with_case_fallback(original_email)
|
||||
except Unauthorized as exc:
|
||||
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
|
||||
raise AccountBannedError() from exc
|
||||
except AccountRegisterError:
|
||||
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
|
||||
raise AccountInFreezeError()
|
||||
if account:
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
@ -279,6 +298,7 @@ class EmailCodeLoginApi(Resource):
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
raise NotAllowedCreateWorkspace()
|
||||
except AccountRegisterError:
|
||||
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
|
||||
raise AccountInFreezeError()
|
||||
except WorkspacesLimitExceededError:
|
||||
raise WorkspacesLimitExceeded()
|
||||
@ -336,3 +356,12 @@ def _authenticate_account_with_case_fallback(
|
||||
if original_email == normalized_email:
|
||||
raise
|
||||
return AccountService.authenticate(normalized_email, password, invite_token)
|
||||
|
||||
|
||||
def _log_console_login_failure(*, email: str, reason: LoginFailureReason) -> None:
|
||||
logger.warning(
|
||||
"Console login failed: email=%s reason=%s ip_address=%s",
|
||||
email,
|
||||
reason,
|
||||
extract_remote_ip(request),
|
||||
)
|
||||
|
||||
@ -1026,7 +1026,7 @@ class DocumentMetadataApi(DocumentResource):
|
||||
|
||||
if not isinstance(doc_metadata, dict):
|
||||
raise ValueError("doc_metadata must be a dictionary.")
|
||||
metadata_schema: dict = cast(dict, DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
|
||||
metadata_schema: dict[str, Any] = cast(dict[str, Any], DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
|
||||
|
||||
document.doc_metadata = {}
|
||||
if doc_type == "others":
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
from flask_restx import Resource, fields
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from fields.hit_testing_fields import (
|
||||
child_chunk_fields,
|
||||
document_fields,
|
||||
files_fields,
|
||||
hit_testing_record_fields,
|
||||
segment_fields,
|
||||
)
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import login_required
|
||||
|
||||
from .. import console_ns
|
||||
@ -18,39 +18,92 @@ from ..wraps import (
|
||||
setup_required,
|
||||
)
|
||||
|
||||
register_schema_model(console_ns, HitTestingPayload)
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
def _get_or_create_model(model_name: str, field_def):
|
||||
"""Get or create a flask_restx model to avoid dict type issues in Swagger."""
|
||||
existing = console_ns.models.get(model_name)
|
||||
if existing is None:
|
||||
existing = console_ns.model(model_name, field_def)
|
||||
return existing
|
||||
class HitTestingDocument(ResponseModel):
|
||||
id: str | None = None
|
||||
data_source_type: str | None = None
|
||||
name: str | None = None
|
||||
doc_type: str | None = None
|
||||
doc_metadata: Any | None = None
|
||||
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
document_model = _get_or_create_model("HitTestingDocument", document_fields)
|
||||
class HitTestingSegment(ResponseModel):
|
||||
id: str | None = None
|
||||
position: int | None = None
|
||||
document_id: str | None = None
|
||||
content: str | None = None
|
||||
sign_content: str | None = None
|
||||
answer: str | None = None
|
||||
word_count: int | None = None
|
||||
tokens: int | None = None
|
||||
keywords: list[str] = Field(default_factory=list)
|
||||
index_node_id: str | None = None
|
||||
index_node_hash: str | None = None
|
||||
hit_count: int | None = None
|
||||
enabled: bool | None = None
|
||||
disabled_at: int | None = None
|
||||
disabled_by: str | None = None
|
||||
status: str | None = None
|
||||
created_by: str | None = None
|
||||
created_at: int | None = None
|
||||
indexing_at: int | None = None
|
||||
completed_at: int | None = None
|
||||
error: str | None = None
|
||||
stopped_at: int | None = None
|
||||
document: HitTestingDocument | None = None
|
||||
|
||||
segment_fields_copy = segment_fields.copy()
|
||||
segment_fields_copy["document"] = fields.Nested(document_model)
|
||||
segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
|
||||
@field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
|
||||
files_model = _get_or_create_model("HitTestingFile", files_fields)
|
||||
|
||||
hit_testing_record_fields_copy = hit_testing_record_fields.copy()
|
||||
hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
|
||||
hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
|
||||
hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
|
||||
hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
|
||||
class HitTestingChildChunk(ResponseModel):
|
||||
id: str | None = None
|
||||
content: str | None = None
|
||||
position: int | None = None
|
||||
score: float | None = None
|
||||
|
||||
# Response model for hit testing API
|
||||
hit_testing_response_fields = {
|
||||
"query": fields.String,
|
||||
"records": fields.List(fields.Nested(hit_testing_record_model)),
|
||||
}
|
||||
hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
|
||||
|
||||
class HitTestingFile(ResponseModel):
|
||||
id: str | None = None
|
||||
name: str | None = None
|
||||
size: int | None = None
|
||||
extension: str | None = None
|
||||
mime_type: str | None = None
|
||||
source_url: str | None = None
|
||||
|
||||
|
||||
class HitTestingRecord(ResponseModel):
|
||||
segment: HitTestingSegment | None = None
|
||||
child_chunks: list[HitTestingChildChunk] = Field(default_factory=list)
|
||||
score: float | None = None
|
||||
tsne_position: Any | None = None
|
||||
files: list[HitTestingFile] = Field(default_factory=list)
|
||||
summary: str | None = None
|
||||
|
||||
|
||||
class HitTestingResponse(ResponseModel):
|
||||
query: str
|
||||
records: list[HitTestingRecord] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
HitTestingPayload,
|
||||
HitTestingDocument,
|
||||
HitTestingSegment,
|
||||
HitTestingChildChunk,
|
||||
HitTestingFile,
|
||||
HitTestingRecord,
|
||||
HitTestingResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
|
||||
@ -59,7 +112,11 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
@console_ns.doc(description="Test dataset knowledge retrieval")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
|
||||
@console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Hit testing completed successfully",
|
||||
model=console_ns.models[HitTestingResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@setup_required
|
||||
@ -74,4 +131,4 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
self.hit_testing_args_check(args)
|
||||
|
||||
return self.perform_hit_testing(dataset, args)
|
||||
return HitTestingResponse.model_validate(self.perform_hit_testing(dataset, args)).model_dump(mode="json")
|
||||
|
||||
@ -1,21 +1,24 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from graphon.file import helpers as file_helpers
|
||||
from pydantic import BaseModel, Field, computed_field, field_validator
|
||||
from sqlalchemy import and_, select
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
from controllers.common.schema import get_or_create_model
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from extensions.ext_database import db
|
||||
from fields.installed_app_fields import app_fields, installed_app_fields, installed_app_list_fields
|
||||
from fields.base import ResponseModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, InstalledApp, RecommendedApp
|
||||
from models.model import IconType
|
||||
from services.account_service import TenantService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
@ -36,22 +39,97 @@ class InstalledAppsListQuery(BaseModel):
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
app_model = get_or_create_model("InstalledAppInfo", app_fields)
|
||||
def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
|
||||
if icon is None or icon_type is None:
|
||||
return None
|
||||
icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
|
||||
if icon_type_value.lower() != IconType.IMAGE:
|
||||
return None
|
||||
return file_helpers.get_signed_file_url(icon)
|
||||
|
||||
installed_app_fields_copy = installed_app_fields.copy()
|
||||
installed_app_fields_copy["app"] = fields.Nested(app_model)
|
||||
installed_app_model = get_or_create_model("InstalledApp", installed_app_fields_copy)
|
||||
|
||||
installed_app_list_fields_copy = installed_app_list_fields.copy()
|
||||
installed_app_list_fields_copy["installed_apps"] = fields.List(fields.Nested(installed_app_model))
|
||||
installed_app_list_model = get_or_create_model("InstalledAppList", installed_app_list_fields_copy)
|
||||
def _safe_primitive(value: Any) -> Any:
|
||||
if value is None or isinstance(value, (str, int, float, bool, datetime)):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
class InstalledAppInfoResponse(ResponseModel):
|
||||
id: str
|
||||
name: str | None = None
|
||||
mode: str | None = None
|
||||
icon_type: str | None = None
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
use_icon_as_answer_icon: bool | None = None
|
||||
|
||||
@field_validator("mode", "icon_type", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum_like(cls, value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(getattr(value, "value", value))
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore[prop-decorator]
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return _build_icon_url(self.icon_type, self.icon)
|
||||
|
||||
|
||||
class InstalledAppResponse(ResponseModel):
|
||||
id: str
|
||||
app: InstalledAppInfoResponse
|
||||
app_owner_tenant_id: str
|
||||
is_pinned: bool
|
||||
last_used_at: int | None = None
|
||||
editable: bool
|
||||
uninstallable: bool
|
||||
|
||||
@field_validator("app", mode="before")
|
||||
@classmethod
|
||||
def _normalize_app(cls, value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
return {
|
||||
"id": _safe_primitive(getattr(value, "id", "")) or "",
|
||||
"name": _safe_primitive(getattr(value, "name", None)),
|
||||
"mode": _safe_primitive(getattr(value, "mode", None)),
|
||||
"icon_type": _safe_primitive(getattr(value, "icon_type", None)),
|
||||
"icon": _safe_primitive(getattr(value, "icon", None)),
|
||||
"icon_background": _safe_primitive(getattr(value, "icon_background", None)),
|
||||
"use_icon_as_answer_icon": _safe_primitive(getattr(value, "use_icon_as_answer_icon", None)),
|
||||
}
|
||||
|
||||
@field_validator("last_used_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class InstalledAppListResponse(ResponseModel):
|
||||
installed_apps: list[InstalledAppResponse]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
InstalledAppCreatePayload,
|
||||
InstalledAppUpdatePayload,
|
||||
InstalledAppsListQuery,
|
||||
InstalledAppInfoResponse,
|
||||
InstalledAppResponse,
|
||||
InstalledAppListResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps")
|
||||
class InstalledAppsListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(installed_app_list_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__])
|
||||
def get(self):
|
||||
query = InstalledAppsListQuery.model_validate(request.args.to_dict())
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@ -125,7 +203,9 @@ class InstalledAppsListApi(Resource):
|
||||
)
|
||||
)
|
||||
|
||||
return {"installed_apps": installed_app_list}
|
||||
return InstalledAppListResponse.model_validate(
|
||||
{"installed_apps": installed_app_list}, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@ -1,66 +1,83 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, computed_field, field_validator
|
||||
|
||||
from constants.languages import languages
|
||||
from controllers.common.schema import get_or_create_model
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from libs.helper import AppIconUrlField
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import build_icon_url
|
||||
from libs.login import current_user, login_required
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
app_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"mode": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_type": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"icon_background": fields.String,
|
||||
}
|
||||
|
||||
app_model = get_or_create_model("RecommendedAppInfo", app_fields)
|
||||
|
||||
recommended_app_fields = {
|
||||
"app": fields.Nested(app_model, attribute="app"),
|
||||
"app_id": fields.String,
|
||||
"description": fields.String(attribute="description"),
|
||||
"copyright": fields.String,
|
||||
"privacy_policy": fields.String,
|
||||
"custom_disclaimer": fields.String,
|
||||
"category": fields.String,
|
||||
"position": fields.Integer,
|
||||
"is_listed": fields.Boolean,
|
||||
"can_trial": fields.Boolean,
|
||||
}
|
||||
|
||||
recommended_app_model = get_or_create_model("RecommendedApp", recommended_app_fields)
|
||||
|
||||
recommended_app_list_fields = {
|
||||
"recommended_apps": fields.List(fields.Nested(recommended_app_model)),
|
||||
"categories": fields.List(fields.String),
|
||||
}
|
||||
|
||||
recommended_app_list_model = get_or_create_model("RecommendedAppList", recommended_app_list_fields)
|
||||
|
||||
|
||||
class RecommendedAppsQuery(BaseModel):
|
||||
language: str | None = Field(default=None)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
RecommendedAppsQuery.__name__,
|
||||
RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
class RecommendedAppInfoResponse(ResponseModel):
|
||||
id: str
|
||||
name: str | None = None
|
||||
mode: str | None = None
|
||||
icon: str | None = None
|
||||
icon_type: str | None = None
|
||||
icon_background: str | None = None
|
||||
|
||||
@staticmethod
|
||||
def _normalize_enum_like(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(getattr(value, "value", value))
|
||||
|
||||
@field_validator("mode", "icon_type", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum_fields(cls, value: Any) -> str | None:
|
||||
return cls._normalize_enum_like(value)
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore[prop-decorator]
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return build_icon_url(self.icon_type, self.icon)
|
||||
|
||||
|
||||
class RecommendedAppResponse(ResponseModel):
|
||||
app: RecommendedAppInfoResponse | None = None
|
||||
app_id: str
|
||||
description: str | None = None
|
||||
copyright: str | None = None
|
||||
privacy_policy: str | None = None
|
||||
custom_disclaimer: str | None = None
|
||||
category: str | None = None
|
||||
position: int | None = None
|
||||
is_listed: bool | None = None
|
||||
can_trial: bool | None = None
|
||||
|
||||
|
||||
class RecommendedAppListResponse(ResponseModel):
|
||||
recommended_apps: list[RecommendedAppResponse]
|
||||
categories: list[str]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
RecommendedAppsQuery,
|
||||
RecommendedAppInfoResponse,
|
||||
RecommendedAppResponse,
|
||||
RecommendedAppListResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/explore/apps")
|
||||
class RecommendedAppListApi(Resource):
|
||||
@console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[RecommendedAppListResponse.__name__])
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(recommended_app_list_model)
|
||||
def get(self):
|
||||
# language args
|
||||
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
@ -72,7 +89,10 @@ class RecommendedAppListApi(Resource):
|
||||
else:
|
||||
language_prefix = languages[0]
|
||||
|
||||
return RecommendedAppService.get_recommended_apps_and_categories(language_prefix)
|
||||
return RecommendedAppListResponse.model_validate(
|
||||
RecommendedAppService.get_recommended_apps_and_categories(language_prefix),
|
||||
from_attributes=True,
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/explore/apps/<uuid:app_id>")
|
||||
|
||||
@ -169,6 +169,7 @@ console_ns.schema_model(
|
||||
|
||||
|
||||
class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
@console_ns.expect(console_ns.models[WorkflowRunRequest.__name__])
|
||||
def post(self, trial_app):
|
||||
"""
|
||||
@ -210,6 +211,7 @@ class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialAppWorkflowTaskStopApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app, task_id: str):
|
||||
"""
|
||||
Stop workflow task
|
||||
@ -290,7 +292,6 @@ class TrialChatApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialMessageSuggestedQuestionApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def get(self, trial_app, message_id):
|
||||
app_model = trial_app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
@ -470,7 +471,6 @@ class TrialCompletionApi(TrialAppResource):
|
||||
class TrialSitApi(Resource):
|
||||
"""Resource for trial app sites."""
|
||||
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial(None)
|
||||
def get(self, app_model):
|
||||
"""Retrieve app site info.
|
||||
@ -492,7 +492,6 @@ class TrialSitApi(Resource):
|
||||
class TrialAppParameterApi(Resource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial(None)
|
||||
def get(self, app_model):
|
||||
"""Retrieve app parameters."""
|
||||
@ -521,7 +520,6 @@ class TrialAppParameterApi(Resource):
|
||||
|
||||
|
||||
class AppApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial(None)
|
||||
@marshal_with(app_detail_with_site_model)
|
||||
def get(self, app_model):
|
||||
@ -534,7 +532,6 @@ class AppApi(Resource):
|
||||
|
||||
|
||||
class AppWorkflowApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial(None)
|
||||
@marshal_with(workflow_model)
|
||||
def get(self, app_model):
|
||||
@ -547,7 +544,6 @@ class AppWorkflowApi(Resource):
|
||||
|
||||
|
||||
class DatasetListApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial(None)
|
||||
def get(self, app_model):
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
|
||||
@ -1,15 +1,18 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from fields.api_based_extension_fields import api_based_extension_fields
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
from services.code_based_extension_service import CodeBasedExtensionService
|
||||
|
||||
from ..common.schema import register_schema_models
|
||||
from ..common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_models
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, setup_required
|
||||
|
||||
@ -24,12 +27,52 @@ class APIBasedExtensionPayload(BaseModel):
|
||||
api_key: str = Field(description="API key for authentication")
|
||||
|
||||
|
||||
register_schema_models(console_ns, APIBasedExtensionPayload)
|
||||
class CodeBasedExtensionResponse(ResponseModel):
|
||||
module: str = Field(description="Module name")
|
||||
data: Any = Field(description="Extension data")
|
||||
|
||||
|
||||
api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields)
|
||||
def _mask_api_key(api_key: str) -> str:
|
||||
if not api_key:
|
||||
return api_key
|
||||
if len(api_key) <= 8:
|
||||
return api_key[0] + "******" + api_key[-1]
|
||||
return api_key[:3] + "******" + api_key[-3:]
|
||||
|
||||
api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model))
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class APIBasedExtensionResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
api_endpoint: str
|
||||
api_key: str
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("api_key", mode="before")
|
||||
@classmethod
|
||||
def _normalize_api_key(cls, value: str) -> str:
|
||||
return _mask_api_key(value)
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
register_schema_models(console_ns, APIBasedExtensionPayload, CodeBasedExtensionResponse, APIBasedExtensionResponse)
|
||||
console_ns.schema_model(
|
||||
"APIBasedExtensionListResponse",
|
||||
TypeAdapter(list[APIBasedExtensionResponse]).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
def _serialize_api_based_extension(extension: APIBasedExtension) -> dict[str, Any]:
|
||||
return APIBasedExtensionResponse.model_validate(extension, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/code-based-extension")
|
||||
@ -40,10 +83,7 @@ class CodeBasedExtensionAPI(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model(
|
||||
"CodeBasedExtensionResponse",
|
||||
{"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")},
|
||||
),
|
||||
console_ns.models[CodeBasedExtensionResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -51,30 +91,34 @@ class CodeBasedExtensionAPI(Resource):
|
||||
def get(self):
|
||||
query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
return {"module": query.module, "data": CodeBasedExtensionService.get_code_based_extension(query.module)}
|
||||
return CodeBasedExtensionResponse(
|
||||
module=query.module,
|
||||
data=CodeBasedExtensionService.get_code_based_extension(query.module),
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/api-based-extension")
|
||||
class APIBasedExtensionAPI(Resource):
|
||||
@console_ns.doc("get_api_based_extensions")
|
||||
@console_ns.doc(description="Get all API-based extensions for current tenant")
|
||||
@console_ns.response(200, "Success", api_based_extension_list_model)
|
||||
@console_ns.response(200, "Success", console_ns.models["APIBasedExtensionListResponse"])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_model)
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
||||
return [
|
||||
_serialize_api_based_extension(extension)
|
||||
for extension in APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
||||
]
|
||||
|
||||
@console_ns.doc("create_api_based_extension")
|
||||
@console_ns.doc(description="Create a new API-based extension")
|
||||
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
|
||||
@console_ns.response(201, "Extension created successfully", api_based_extension_model)
|
||||
@console_ns.response(201, "Extension created successfully", console_ns.models[APIBasedExtensionResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_model)
|
||||
def post(self):
|
||||
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@ -86,7 +130,7 @@ class APIBasedExtensionAPI(Resource):
|
||||
api_key=payload.api_key,
|
||||
)
|
||||
|
||||
return APIBasedExtensionService.save(extension_data)
|
||||
return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data))
|
||||
|
||||
|
||||
@console_ns.route("/api-based-extension/<uuid:id>")
|
||||
@ -94,26 +138,26 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
@console_ns.doc("get_api_based_extension")
|
||||
@console_ns.doc(description="Get API-based extension by ID")
|
||||
@console_ns.doc(params={"id": "Extension ID"})
|
||||
@console_ns.response(200, "Success", api_based_extension_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[APIBasedExtensionResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_model)
|
||||
def get(self, id):
|
||||
api_based_extension_id = str(id)
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||
return _serialize_api_based_extension(
|
||||
APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||
)
|
||||
|
||||
@console_ns.doc("update_api_based_extension")
|
||||
@console_ns.doc(description="Update API-based extension")
|
||||
@console_ns.doc(params={"id": "Extension ID"})
|
||||
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
|
||||
@console_ns.response(200, "Extension updated successfully", api_based_extension_model)
|
||||
@console_ns.response(200, "Extension updated successfully", console_ns.models[APIBasedExtensionResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_model)
|
||||
def post(self, id):
|
||||
api_based_extension_id = str(id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@ -128,7 +172,7 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
if payload.api_key != HIDDEN_VALUE:
|
||||
extension_data_from_db.api_key = payload.api_key
|
||||
|
||||
return APIBasedExtensionService.save(extension_data_from_db)
|
||||
return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data_from_db))
|
||||
|
||||
@console_ns.doc("delete_api_based_extension")
|
||||
@console_ns.doc(description="Delete API-based extension")
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import TypedDict
|
||||
|
||||
from flask import request
|
||||
@ -13,6 +14,14 @@ from services.billing_service import BillingService
|
||||
_FALLBACK_LANG = "en-US"
|
||||
|
||||
|
||||
class NotificationLangContent(TypedDict, total=False):
|
||||
lang: str
|
||||
title: str
|
||||
subtitle: str
|
||||
body: str
|
||||
titlePicUrl: str
|
||||
|
||||
|
||||
class NotificationItemDict(TypedDict):
|
||||
notification_id: str | None
|
||||
frequency: str | None
|
||||
@ -28,9 +37,11 @@ class NotificationResponseDict(TypedDict):
|
||||
notifications: list[NotificationItemDict]
|
||||
|
||||
|
||||
def _pick_lang_content(contents: dict, lang: str) -> dict:
|
||||
def _pick_lang_content(contents: Mapping[str, NotificationLangContent], lang: str) -> NotificationLangContent:
|
||||
"""Return the single LangContent for *lang*, falling back to English."""
|
||||
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
|
||||
return (
|
||||
contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), NotificationLangContent())
|
||||
)
|
||||
|
||||
|
||||
class DismissNotificationPayload(BaseModel):
|
||||
@ -71,7 +82,7 @@ class NotificationApi(Resource):
|
||||
|
||||
notifications: list[NotificationItemDict] = []
|
||||
for notification in result.get("notifications") or []:
|
||||
contents: dict = notification.get("contents") or {}
|
||||
contents: Mapping[str, NotificationLangContent] = notification.get("contents") or {}
|
||||
lang_content = _pick_lang_content(contents, lang)
|
||||
item: NotificationItemDict = {
|
||||
"notification_id": notification.get("notificationId"),
|
||||
|
||||
@ -1 +0,0 @@
|
||||
|
||||
@ -1,119 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from flask import Request as FlaskRequest
|
||||
|
||||
from extensions.ext_socketio import sio
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_access_token
|
||||
from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository
|
||||
from services.account_service import AccountService
|
||||
from services.workflow_collaboration_service import WorkflowCollaborationService
|
||||
|
||||
repository = WorkflowCollaborationRepository()
|
||||
collaboration_service = WorkflowCollaborationService(repository, sio)
|
||||
|
||||
|
||||
def _sio_on(event: str) -> Callable[[Callable[..., object]], Callable[..., object]]:
|
||||
return cast(Callable[[Callable[..., object]], Callable[..., object]], sio.on(event))
|
||||
|
||||
|
||||
@_sio_on("connect")
|
||||
def socket_connect(sid, environ, auth):
|
||||
"""
|
||||
WebSocket connect event, do authentication here.
|
||||
"""
|
||||
try:
|
||||
request_environ = FlaskRequest(environ)
|
||||
token = extract_access_token(request_environ)
|
||||
except Exception:
|
||||
logging.exception("Failed to extract token")
|
||||
token = None
|
||||
|
||||
if not token:
|
||||
logging.warning("Socket connect rejected: missing token (sid=%s)", sid)
|
||||
return False
|
||||
|
||||
try:
|
||||
decoded = PassportService().verify(token)
|
||||
user_id = decoded.get("user_id")
|
||||
if not user_id:
|
||||
logging.warning("Socket connect rejected: missing user_id (sid=%s)", sid)
|
||||
return False
|
||||
|
||||
with sio.app.app_context():
|
||||
user = AccountService.load_logged_in_account(account_id=user_id)
|
||||
if not user:
|
||||
logging.warning("Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid)
|
||||
return False
|
||||
if not user.has_edit_permission:
|
||||
logging.warning("Socket connect rejected: no edit permission (user_id=%s, sid=%s)", user_id, sid)
|
||||
return False
|
||||
|
||||
collaboration_service.save_session(sid, user)
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
logging.exception("Socket authentication failed")
|
||||
return False
|
||||
|
||||
|
||||
@_sio_on("user_connect")
|
||||
def handle_user_connect(sid, data):
|
||||
"""
|
||||
Handle user connect event. Each session (tab) is treated as an independent collaborator.
|
||||
"""
|
||||
workflow_id = data.get("workflow_id")
|
||||
if not workflow_id:
|
||||
return {"msg": "workflow_id is required"}, 400
|
||||
|
||||
result = collaboration_service.register_session(workflow_id, sid)
|
||||
if not result:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
user_id, is_leader = result
|
||||
return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader}
|
||||
|
||||
|
||||
@_sio_on("disconnect")
|
||||
def handle_disconnect(sid):
|
||||
"""
|
||||
Handle session disconnect event. Remove the specific session from online users.
|
||||
"""
|
||||
collaboration_service.disconnect_session(sid)
|
||||
|
||||
|
||||
@_sio_on("collaboration_event")
|
||||
def handle_collaboration_event(sid, data):
|
||||
"""
|
||||
Handle general collaboration events, include:
|
||||
1. mouse_move
|
||||
2. vars_and_features_update
|
||||
3. sync_request (ask leader to update graph)
|
||||
4. app_state_update
|
||||
5. mcp_server_update
|
||||
6. workflow_update
|
||||
7. comments_update
|
||||
8. node_panel_presence
|
||||
9. skill_file_active
|
||||
10. skill_sync_request
|
||||
11. skill_resync_request
|
||||
"""
|
||||
return collaboration_service.relay_collaboration_event(sid, data)
|
||||
|
||||
|
||||
@_sio_on("graph_event")
|
||||
def handle_graph_event(sid, data):
|
||||
"""
|
||||
Handle graph events - simple broadcast relay.
|
||||
"""
|
||||
return collaboration_service.relay_graph_event(sid, data)
|
||||
|
||||
|
||||
@_sio_on("skill_event")
|
||||
def handle_skill_event(sid, data):
|
||||
"""
|
||||
Handle skill events - simple broadcast relay.
|
||||
"""
|
||||
return collaboration_service.relay_skill_event(sid, data)
|
||||
@ -1,13 +1,14 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Namespace, Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import TagType
|
||||
from services.tag_service import (
|
||||
@ -18,17 +19,6 @@ from services.tag_service import (
|
||||
UpdateTagPayload,
|
||||
)
|
||||
|
||||
dataset_tag_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"type": fields.String,
|
||||
"binding_count": fields.String,
|
||||
}
|
||||
|
||||
|
||||
def build_dataset_tag_fields(api_or_ns: Namespace):
|
||||
return api_or_ns.model("DataSetTag", dataset_tag_fields)
|
||||
|
||||
|
||||
class TagBasePayload(BaseModel):
|
||||
name: str = Field(description="Tag name", min_length=1, max_length=50)
|
||||
@ -52,12 +42,36 @@ class TagListQueryParam(BaseModel):
|
||||
keyword: str | None = Field(None, description="Search keyword")
|
||||
|
||||
|
||||
class TagResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: str | None = None
|
||||
binding_count: str | None = None
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
def normalize_type(cls, value: TagType | str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, TagType):
|
||||
return value.value
|
||||
return value
|
||||
|
||||
@field_validator("binding_count", mode="before")
|
||||
@classmethod
|
||||
def normalize_binding_count(cls, value: int | str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
TagBasePayload,
|
||||
TagBindingPayload,
|
||||
TagBindingRemovePayload,
|
||||
TagListQueryParam,
|
||||
TagResponse,
|
||||
)
|
||||
|
||||
|
||||
@ -69,14 +83,18 @@ class TagListApi(Resource):
|
||||
@console_ns.doc(
|
||||
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
|
||||
)
|
||||
@marshal_with(dataset_tag_fields)
|
||||
@console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])})
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
raw_args = request.args.to_dict()
|
||||
param = TagListQueryParam.model_validate(raw_args)
|
||||
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
|
||||
|
||||
return tags, 200
|
||||
serialized_tags = [
|
||||
TagResponse.model_validate(tag, from_attributes=True).model_dump(mode="json") for tag in tags
|
||||
]
|
||||
|
||||
return serialized_tags, 200
|
||||
|
||||
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
|
||||
@setup_required
|
||||
@ -91,7 +109,9 @@ class TagListApi(Resource):
|
||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type))
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
response = TagResponse.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
).model_dump(mode="json")
|
||||
|
||||
return response, 200
|
||||
|
||||
@ -114,7 +134,9 @@ class TagUpdateDeleteApi(Resource):
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
response = TagResponse.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
).model_dump(mode="json")
|
||||
|
||||
return response, 200
|
||||
|
||||
|
||||
@ -35,22 +35,24 @@ def plugin_permission_required(
|
||||
return view(*args, **kwargs)
|
||||
|
||||
if install_required:
|
||||
if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY:
|
||||
raise Forbidden()
|
||||
if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS:
|
||||
if not user.is_admin_or_owner:
|
||||
match permission.install_permission:
|
||||
case TenantPluginPermission.InstallPermission.NOBODY:
|
||||
raise Forbidden()
|
||||
if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE:
|
||||
pass
|
||||
case TenantPluginPermission.InstallPermission.ADMINS:
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
case TenantPluginPermission.InstallPermission.EVERYONE:
|
||||
pass
|
||||
|
||||
if debug_required:
|
||||
if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY:
|
||||
raise Forbidden()
|
||||
if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS:
|
||||
if not user.is_admin_or_owner:
|
||||
match permission.debug_permission:
|
||||
case TenantPluginPermission.DebugPermission.NOBODY:
|
||||
raise Forbidden()
|
||||
if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE:
|
||||
pass
|
||||
case TenantPluginPermission.DebugPermission.ADMINS:
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
case TenantPluginPermission.DebugPermission.EVERYONE:
|
||||
pass
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
import pytz
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from sqlalchemy import select
|
||||
|
||||
@ -37,9 +37,10 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from fields.member_fields import Account as AccountResponse
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
|
||||
from libs.helper import EmailStr, extract_remote_ip, timezone
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AccountIntegrate, InvitationCode
|
||||
from models.account import AccountStatus, InvitationCodeStatus
|
||||
@ -174,21 +175,61 @@ reg(CheckEmailUniquePayload)
|
||||
register_schema_models(console_ns, AccountResponse)
|
||||
|
||||
|
||||
def _serialize_account(account) -> dict:
|
||||
def _serialize_account(account) -> dict[str, Any]:
|
||||
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
integrate_fields = {
|
||||
"provider": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"is_bound": fields.Boolean,
|
||||
"link": fields.String,
|
||||
}
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
integrate_model = console_ns.model("AccountIntegrate", integrate_fields)
|
||||
integrate_list_model = console_ns.model(
|
||||
"AccountIntegrateList",
|
||||
{"data": fields.List(fields.Nested(integrate_model))},
|
||||
|
||||
class AccountIntegrateResponse(ResponseModel):
|
||||
provider: str
|
||||
created_at: int | None = None
|
||||
is_bound: bool
|
||||
link: str | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class AccountIntegrateListResponse(ResponseModel):
|
||||
data: list[AccountIntegrateResponse]
|
||||
|
||||
|
||||
class EducationVerifyResponse(ResponseModel):
|
||||
token: str | None = None
|
||||
|
||||
|
||||
class EducationStatusResponse(ResponseModel):
|
||||
result: bool | None = None
|
||||
is_student: bool | None = None
|
||||
expire_at: int | None = None
|
||||
allow_refresh: bool | None = None
|
||||
|
||||
@field_validator("expire_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_expire_at(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class EducationAutocompleteResponse(ResponseModel):
|
||||
data: list[str] = Field(default_factory=list)
|
||||
curr_page: int | None = None
|
||||
has_next: bool | None = None
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
AccountIntegrateResponse,
|
||||
AccountIntegrateListResponse,
|
||||
EducationVerifyResponse,
|
||||
EducationStatusResponse,
|
||||
EducationAutocompleteResponse,
|
||||
)
|
||||
|
||||
|
||||
@ -359,7 +400,7 @@ class AccountIntegrateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(integrate_list_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountIntegrateListResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@ -395,7 +436,9 @@ class AccountIntegrateApi(Resource):
|
||||
}
|
||||
)
|
||||
|
||||
return {"data": integrate_data}
|
||||
return AccountIntegrateListResponse(
|
||||
data=[AccountIntegrateResponse.model_validate(item) for item in integrate_data]
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/account/delete/verify")
|
||||
@ -447,31 +490,22 @@ class AccountDeleteUpdateFeedbackApi(Resource):
|
||||
|
||||
@console_ns.route("/account/education/verify")
|
||||
class EducationVerifyApi(Resource):
|
||||
verify_fields = {
|
||||
"token": fields.String,
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
@cloud_edition_billing_enabled
|
||||
@marshal_with(verify_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[EducationVerifyResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
return BillingService.EducationIdentity.verify(account.id, account.email)
|
||||
return EducationVerifyResponse.model_validate(
|
||||
BillingService.EducationIdentity.verify(account.id, account.email) or {}
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/account/education")
|
||||
class EducationApi(Resource):
|
||||
status_fields = {
|
||||
"result": fields.Boolean,
|
||||
"is_student": fields.Boolean,
|
||||
"expire_at": TimestampField,
|
||||
"allow_refresh": fields.Boolean,
|
||||
}
|
||||
|
||||
@console_ns.expect(console_ns.models[EducationActivatePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -491,37 +525,33 @@ class EducationApi(Resource):
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
@cloud_edition_billing_enabled
|
||||
@marshal_with(status_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[EducationStatusResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
res = BillingService.EducationIdentity.status(account.id)
|
||||
res = BillingService.EducationIdentity.status(account.id) or {}
|
||||
# convert expire_at to UTC timestamp from isoformat
|
||||
if res and "expire_at" in res:
|
||||
res["expire_at"] = datetime.fromisoformat(res["expire_at"]).astimezone(pytz.utc)
|
||||
return res
|
||||
return EducationStatusResponse.model_validate(res).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/account/education/autocomplete")
|
||||
class EducationAutoCompleteApi(Resource):
|
||||
data_fields = {
|
||||
"data": fields.List(fields.String),
|
||||
"curr_page": fields.Integer,
|
||||
"has_next": fields.Boolean,
|
||||
}
|
||||
|
||||
@console_ns.expect(console_ns.models[EducationAutocompleteQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
@cloud_edition_billing_enabled
|
||||
@marshal_with(data_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[EducationAutocompleteResponse.__name__])
|
||||
def get(self):
|
||||
payload = request.args.to_dict(flat=True)
|
||||
args = EducationAutocompleteQuery.model_validate(payload)
|
||||
|
||||
return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit)
|
||||
return EducationAutocompleteResponse.model_validate(
|
||||
BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit) or {}
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email")
|
||||
|
||||
@ -1,67 +0,0 @@
|
||||
import json
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
from services.app_dsl_service import AppDslService
|
||||
|
||||
|
||||
class DSLPredictRequest(BaseModel):
|
||||
app_id: str
|
||||
current_node_id: str
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/dsl/predict")
|
||||
class DSLPredictApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, _ = current_account_with_tenant()
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
args = DSLPredictRequest.model_validate(request.get_json())
|
||||
|
||||
app_id: str = args.app_id
|
||||
current_node_id: str = args.current_node_id
|
||||
|
||||
with Session(db.engine) as session:
|
||||
app = session.query(App).filter_by(id=app_id).first()
|
||||
workflow = session.query(Workflow).filter_by(app_id=app_id, version=Workflow.VERSION_DRAFT).first()
|
||||
|
||||
if not app:
|
||||
raise ValueError("App not found")
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
try:
|
||||
i = 0
|
||||
for node_id, _ in workflow.walk_nodes():
|
||||
if node_id == current_node_id:
|
||||
break
|
||||
i += 1
|
||||
|
||||
dsl = yaml.safe_load(AppDslService.export_dsl(app_model=app))
|
||||
|
||||
response = httpx.post(
|
||||
"http://spark-832c:8000/predict",
|
||||
json={"graph_data": dsl, "source_node_index": i},
|
||||
)
|
||||
return {
|
||||
"nodes": json.loads(response.json()),
|
||||
}
|
||||
except PluginPermissionDeniedError as e:
|
||||
raise ValueError(e.description) from e
|
||||
@ -465,7 +465,7 @@ class ModelProviderModelDisableApi(Resource):
|
||||
class ParserValidate(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
credentials: dict
|
||||
credentials: dict[str, Any]
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource, fields, marshal
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
@ -26,6 +27,7 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
|
||||
@ -58,6 +60,37 @@ class WorkspaceInfoPayload(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class TenantInfoResponse(ResponseModel):
|
||||
id: str
|
||||
name: str | None = None
|
||||
plan: str | None = None
|
||||
status: str | None = None
|
||||
created_at: int | None = None
|
||||
role: str | None = None
|
||||
in_trial: bool | None = None
|
||||
trial_end_reason: str | None = None
|
||||
custom_config: dict | None = None
|
||||
trial_credits: int | None = None
|
||||
trial_credits_used: int | None = None
|
||||
next_credit_reset_date: int | None = None
|
||||
|
||||
@field_validator("plan", "status", "trial_end_reason", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum_like(cls, value):
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(getattr(value, "value", value))
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None):
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
@ -66,6 +99,7 @@ reg(WorkspaceListQuery)
|
||||
reg(SwitchWorkspacePayload)
|
||||
reg(WorkspaceCustomConfigPayload)
|
||||
reg(WorkspaceInfoPayload)
|
||||
reg(TenantInfoResponse)
|
||||
|
||||
provider_fields = {
|
||||
"provider_name": fields.String,
|
||||
@ -180,7 +214,7 @@ class TenantApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(tenant_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[TenantInfoResponse.__name__])
|
||||
def post(self):
|
||||
if request.path == "/info":
|
||||
logger.warning("Deprecated URL /info was used.")
|
||||
@ -200,7 +234,13 @@ class TenantApi(Resource):
|
||||
else:
|
||||
raise Unauthorized("workspace is archived")
|
||||
|
||||
return WorkspaceService.get_tenant_info(tenant), 200
|
||||
return (
|
||||
TenantInfoResponse.model_validate(
|
||||
WorkspaceService.get_tenant_info(tenant),
|
||||
from_attributes=True,
|
||||
).model_dump(mode="json"),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/switch")
|
||||
|
||||
@ -20,7 +20,7 @@ from models.account import AccountStatus
|
||||
from models.dataset import RateLimitLog
|
||||
from models.model import DifySetup
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
from services.operation_service import OperationService
|
||||
from services.operation_service import OperationService, UtmInfo
|
||||
|
||||
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
|
||||
|
||||
@ -205,7 +205,7 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
utm_info = request.cookies.get("utm_info")
|
||||
|
||||
if utm_info:
|
||||
utm_info_dict: dict = json.loads(utm_info)
|
||||
utm_info_dict: UtmInfo = json.loads(utm_info)
|
||||
OperationService.record_utm(current_tenant_id, utm_info_dict)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@ -1,80 +0,0 @@
|
||||
"""Token-based file proxy controller for storage operations.
|
||||
|
||||
This controller handles file download and upload operations using opaque UUID tokens.
|
||||
The token maps to the real storage key in Redis, so the actual storage path is never
|
||||
exposed in the URL.
|
||||
|
||||
Routes:
|
||||
GET /files/storage-files/{token} - Download a file
|
||||
PUT /files/storage-files/{token} - Upload a file
|
||||
|
||||
The operation type (download/upload) is determined by the ticket stored in Redis,
|
||||
not by the HTTP method. This ensures a download ticket cannot be used for upload
|
||||
and vice versa.
|
||||
"""
|
||||
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import Forbidden, NotFound, RequestEntityTooLarge
|
||||
|
||||
from controllers.files import files_ns
|
||||
from extensions.ext_storage import storage
|
||||
from services.storage_ticket_service import StorageTicketService
|
||||
|
||||
|
||||
@files_ns.route("/storage-files/<string:token>")
|
||||
class StorageFilesApi(Resource):
|
||||
"""Handle file operations through token-based URLs."""
|
||||
|
||||
def get(self, token: str):
|
||||
"""Download a file using a token.
|
||||
|
||||
The ticket must have op="download", otherwise returns 403.
|
||||
"""
|
||||
ticket = StorageTicketService.get_ticket(token)
|
||||
if ticket is None:
|
||||
raise Forbidden("Invalid or expired token")
|
||||
|
||||
if ticket.op != "download":
|
||||
raise Forbidden("This token is not valid for download")
|
||||
|
||||
try:
|
||||
generator = storage.load_stream(ticket.storage_key)
|
||||
except FileNotFoundError:
|
||||
raise NotFound("File not found")
|
||||
|
||||
filename = ticket.filename or ticket.storage_key.rsplit("/", 1)[-1]
|
||||
encoded_filename = quote(filename)
|
||||
|
||||
return Response(
|
||||
generator,
|
||||
mimetype="application/octet-stream",
|
||||
direct_passthrough=True,
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
|
||||
},
|
||||
)
|
||||
|
||||
def put(self, token: str):
|
||||
"""Upload a file using a token.
|
||||
|
||||
The ticket must have op="upload", otherwise returns 403.
|
||||
If the request body exceeds max_bytes, returns 413.
|
||||
"""
|
||||
ticket = StorageTicketService.get_ticket(token)
|
||||
if ticket is None:
|
||||
raise Forbidden("Invalid or expired token")
|
||||
|
||||
if ticket.op != "upload":
|
||||
raise Forbidden("This token is not valid for upload")
|
||||
|
||||
content = request.get_data()
|
||||
|
||||
if ticket.max_bytes is not None and len(content) > ticket.max_bytes:
|
||||
raise RequestEntityTooLarge(f"Upload exceeds maximum size of {ticket.max_bytes} bytes")
|
||||
|
||||
storage.save(ticket.storage_key, content)
|
||||
|
||||
return Response(status=204)
|
||||
@ -9,7 +9,7 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.wraps import setup_required
|
||||
@ -56,7 +56,7 @@ class EnterpriseAppDSLImport(Resource):
|
||||
|
||||
account.set_tenant_id(workspace_id)
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
dsl_service = AppDslService(session)
|
||||
result = dsl_service.import_app(
|
||||
account=account,
|
||||
@ -65,6 +65,10 @@ class EnterpriseAppDSLImport(Resource):
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
)
|
||||
if result.status == ImportStatus.FAILED:
|
||||
session.rollback()
|
||||
else:
|
||||
session.commit()
|
||||
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
|
||||
@ -2,7 +2,7 @@ from typing import Any, Union
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource
|
||||
from graphon.variables.input_entities import VariableEntity
|
||||
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
@ -158,14 +158,20 @@ class MCPAppApi(Resource):
|
||||
except ValidationError as e:
|
||||
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
|
||||
|
||||
def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]:
|
||||
def _convert_user_input_form(self, raw_form: list[dict[str, Any]]) -> list[VariableEntity]:
|
||||
"""Convert raw user input form to VariableEntity objects"""
|
||||
return [self._create_variable_entity(item) for item in raw_form]
|
||||
|
||||
def _create_variable_entity(self, item: dict) -> VariableEntity:
|
||||
def _create_variable_entity(self, item: dict[str, Any]) -> VariableEntity:
|
||||
"""Create a single VariableEntity from raw form item"""
|
||||
variable_type = item.get("type", "") or list(item.keys())[0]
|
||||
variable = item[variable_type]
|
||||
variable_type_raw: str = item.get("type", "") or list(item.keys())[0]
|
||||
try:
|
||||
variable_type = VariableEntityType(variable_type_raw)
|
||||
except ValueError as e:
|
||||
raise MCPRequestError(
|
||||
mcp_types.INVALID_PARAMS, f"Invalid user_input_form variable type: {variable_type_raw}"
|
||||
) from e
|
||||
variable = item[variable_type_raw]
|
||||
|
||||
return VariableEntity(
|
||||
type=variable_type,
|
||||
@ -178,7 +184,7 @@ class MCPAppApi(Resource):
|
||||
json_schema=variable.get("json_schema"),
|
||||
)
|
||||
|
||||
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
|
||||
def _parse_mcp_request(self, args: dict[str, Any]) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
|
||||
"""Parse and validate MCP request"""
|
||||
try:
|
||||
return mcp_types.ClientRequest.model_validate(args)
|
||||
|
||||
@ -194,7 +194,7 @@ class ChatApi(Resource):
|
||||
Supports conversation management and both blocking and streaming response modes.
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
payload = ChatRequestPayload.model_validate(service_api_ns.payload or {})
|
||||
@ -258,7 +258,7 @@ class ChatStopApi(Resource):
|
||||
def post(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
"""Stop a running chat message generation."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
AppTaskService.stop_task(
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from graphon.variables.types import SegmentType
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
@ -14,14 +16,12 @@ from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields._value_type_serializer import serialize_value_type
|
||||
from fields.base import ResponseModel
|
||||
from fields.conversation_fields import (
|
||||
ConversationInfiniteScrollPagination,
|
||||
SimpleConversation,
|
||||
)
|
||||
from fields.conversation_variable_fields import (
|
||||
build_conversation_variable_infinite_scroll_pagination_model,
|
||||
build_conversation_variable_model,
|
||||
)
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
@ -70,12 +70,70 @@ class ConversationVariableUpdatePayload(BaseModel):
|
||||
value: Any
|
||||
|
||||
|
||||
class ConversationVariableResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
value_type: str
|
||||
value: str | None = None
|
||||
description: str | None = None
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
|
||||
@field_validator("value_type", mode="before")
|
||||
@classmethod
|
||||
def normalize_value_type(cls, value: Any) -> str:
|
||||
exposed_type = getattr(value, "exposed_type", None)
|
||||
if callable(exposed_type):
|
||||
return str(exposed_type().value)
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return str(SegmentType(value).exposed_type().value)
|
||||
except ValueError:
|
||||
return value
|
||||
try:
|
||||
return serialize_value_type(value)
|
||||
except (AttributeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
try:
|
||||
return serialize_value_type({"value_type": value})
|
||||
except (AttributeError, TypeError, ValueError):
|
||||
value_attr = getattr(value, "value", None)
|
||||
if value_attr is not None:
|
||||
return str(value_attr)
|
||||
return str(value)
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def normalize_value(cls, value: Any | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class ConversationVariableInfiniteScrollPaginationResponse(ResponseModel):
|
||||
limit: int
|
||||
has_more: bool
|
||||
data: list[ConversationVariableResponse]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
ConversationListQuery,
|
||||
ConversationRenamePayload,
|
||||
ConversationVariablesQuery,
|
||||
ConversationVariableUpdatePayload,
|
||||
ConversationVariableResponse,
|
||||
ConversationVariableInfiniteScrollPaginationResponse,
|
||||
)
|
||||
|
||||
|
||||
@ -98,7 +156,7 @@ class ConversationApi(Resource):
|
||||
Supports pagination using last_id and limit parameters.
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
query_args = ConversationListQuery.model_validate(request.args.to_dict())
|
||||
@ -142,7 +200,7 @@ class ConversationDetailApi(Resource):
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""Delete a specific conversation."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@ -171,7 +229,7 @@ class ConversationRenameApi(Resource):
|
||||
def post(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""Rename a conversation or auto-generate a name."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@ -204,8 +262,12 @@ class ConversationVariablesApi(Resource):
|
||||
404: "Conversation not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Variables retrieved successfully",
|
||||
service_api_ns.models[ConversationVariableInfiniteScrollPaginationResponse.__name__],
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
@service_api_ns.marshal_with(build_conversation_variable_infinite_scroll_pagination_model(service_api_ns))
|
||||
def get(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""List all variables for a conversation.
|
||||
|
||||
@ -213,7 +275,7 @@ class ConversationVariablesApi(Resource):
|
||||
"""
|
||||
# conversational variable only for chat app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@ -222,9 +284,12 @@ class ConversationVariablesApi(Resource):
|
||||
last_id = str(query_args.last_id) if query_args.last_id else None
|
||||
|
||||
try:
|
||||
return ConversationService.get_conversational_variable(
|
||||
pagination = ConversationService.get_conversational_variable(
|
||||
app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name
|
||||
)
|
||||
return ConversationVariableInfiniteScrollPaginationResponse.model_validate(
|
||||
pagination, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
@ -243,8 +308,12 @@ class ConversationVariableDetailApi(Resource):
|
||||
404: "Conversation or variable not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Variable updated successfully",
|
||||
service_api_ns.models[ConversationVariableResponse.__name__],
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
@service_api_ns.marshal_with(build_conversation_variable_model(service_api_ns))
|
||||
def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
|
||||
"""Update a conversation variable's value.
|
||||
|
||||
@ -252,7 +321,7 @@ class ConversationVariableDetailApi(Resource):
|
||||
The value must match the variable's expected type.
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@ -261,9 +330,10 @@ class ConversationVariableDetailApi(Resource):
|
||||
payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
return ConversationService.update_conversation_variable(
|
||||
variable = ConversationService.update_conversation_variable(
|
||||
app_model, conversation_id, variable_id, end_user, payload.value
|
||||
)
|
||||
return ConversationVariableResponse.model_validate(variable, from_attributes=True).model_dump(mode="json")
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationVariableNotExistsError:
|
||||
|
||||
@ -53,7 +53,7 @@ class MessageListApi(Resource):
|
||||
Retrieves messages with pagination support using first_id.
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
query_args = MessageListQuery.model_validate(request.args.to_dict())
|
||||
@ -158,7 +158,7 @@ class MessageSuggestedApi(Resource):
|
||||
"""
|
||||
message_id = str(message_id)
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
try:
|
||||
|
||||
@ -1,13 +1,15 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
from flask_restx import Resource, fields
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
@ -33,9 +35,10 @@ from core.errors.error import (
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||
from fields.base import ResponseModel
|
||||
from fields.end_user_fields import SimpleEndUser
|
||||
from fields.member_fields import SimpleAccount
|
||||
from libs import helper
|
||||
from libs.helper import OptionalTimestampField, TimestampField
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
@ -65,38 +68,142 @@ class WorkflowLogQuery(BaseModel):
|
||||
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
def _enum_value(value):
|
||||
return getattr(value, "value", value)
|
||||
|
||||
|
||||
class WorkflowRunStatusField(fields.Raw):
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
return obj.status.value
|
||||
return _enum_value(obj.status)
|
||||
|
||||
|
||||
class WorkflowRunOutputsField(fields.Raw):
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
if obj.status == WorkflowExecutionStatus.PAUSED:
|
||||
status = _enum_value(obj.status)
|
||||
if status == WorkflowExecutionStatus.PAUSED.value:
|
||||
return {}
|
||||
|
||||
outputs = obj.outputs_dict
|
||||
return outputs or {}
|
||||
|
||||
|
||||
workflow_run_fields = {
|
||||
"id": fields.String,
|
||||
"workflow_id": fields.String,
|
||||
"status": WorkflowRunStatusField,
|
||||
"inputs": fields.Raw,
|
||||
"outputs": WorkflowRunOutputsField,
|
||||
"error": fields.String,
|
||||
"total_steps": fields.Integer,
|
||||
"total_tokens": fields.Integer,
|
||||
"created_at": TimestampField,
|
||||
"finished_at": OptionalTimestampField,
|
||||
"elapsed_time": fields.Float,
|
||||
}
|
||||
class WorkflowRunResponse(ResponseModel):
|
||||
id: str
|
||||
workflow_id: str
|
||||
status: str
|
||||
inputs: dict | list | str | int | float | bool | None = None
|
||||
outputs: dict = Field(default_factory=dict)
|
||||
error: str | None = None
|
||||
total_steps: int | None = None
|
||||
total_tokens: int | None = None
|
||||
created_at: int | None = None
|
||||
finished_at: int | None = None
|
||||
elapsed_time: float | int | None = None
|
||||
|
||||
@field_validator("created_at", "finished_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
def build_workflow_run_model(api_or_ns: Namespace):
|
||||
"""Build the workflow run model for the API or Namespace."""
|
||||
return api_or_ns.model("WorkflowRun", workflow_run_fields)
|
||||
class WorkflowRunForLogResponse(ResponseModel):
|
||||
id: str
|
||||
version: str | None = None
|
||||
status: str | None = None
|
||||
triggered_from: str | None = None
|
||||
error: str | None = None
|
||||
elapsed_time: float | int | None = None
|
||||
total_tokens: int | None = None
|
||||
total_steps: int | None = None
|
||||
created_at: int | None = None
|
||||
finished_at: int | None = None
|
||||
exceptions_count: int | None = None
|
||||
|
||||
@field_validator("status", "triggered_from", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum(cls, value):
|
||||
return _enum_value(value)
|
||||
|
||||
@field_validator("created_at", "finished_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowAppLogPartialResponse(ResponseModel):
|
||||
id: str
|
||||
workflow_run: WorkflowRunForLogResponse | None = None
|
||||
details: dict | list | str | int | float | bool | None = None
|
||||
created_from: str | None = None
|
||||
created_by_role: str | None = None
|
||||
created_by_account: SimpleAccount | None = None
|
||||
created_by_end_user: SimpleEndUser | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_from", "created_by_role", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum(cls, value):
|
||||
return _enum_value(value)
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowAppLogPaginationResponse(ResponseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[WorkflowAppLogPartialResponse]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
WorkflowRunResponse,
|
||||
WorkflowRunForLogResponse,
|
||||
WorkflowAppLogPartialResponse,
|
||||
WorkflowAppLogPaginationResponse,
|
||||
)
|
||||
|
||||
|
||||
def _serialize_workflow_run(workflow_run: WorkflowRun) -> dict:
|
||||
status = _enum_value(workflow_run.status)
|
||||
raw_outputs = workflow_run.outputs_dict
|
||||
if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None:
|
||||
outputs: dict = {}
|
||||
elif isinstance(raw_outputs, dict):
|
||||
outputs = raw_outputs
|
||||
elif isinstance(raw_outputs, Mapping):
|
||||
outputs = dict(raw_outputs)
|
||||
else:
|
||||
outputs = {}
|
||||
return WorkflowRunResponse.model_validate(
|
||||
{
|
||||
"id": workflow_run.id,
|
||||
"workflow_id": workflow_run.workflow_id,
|
||||
"status": status,
|
||||
"inputs": workflow_run.inputs,
|
||||
"outputs": outputs,
|
||||
"error": workflow_run.error,
|
||||
"total_steps": workflow_run.total_steps,
|
||||
"total_tokens": workflow_run.total_tokens,
|
||||
"created_at": workflow_run.created_at,
|
||||
"finished_at": workflow_run.finished_at,
|
||||
"elapsed_time": workflow_run.elapsed_time,
|
||||
}
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
def _serialize_workflow_log_pagination(pagination) -> dict:
|
||||
return WorkflowAppLogPaginationResponse.model_validate(pagination, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@service_api_ns.route("/workflows/run/<string:workflow_run_id>")
|
||||
@ -112,7 +219,11 @@ class WorkflowRunDetailApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
@service_api_ns.marshal_with(build_workflow_run_model(service_api_ns))
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Workflow run details retrieved successfully",
|
||||
service_api_ns.models[WorkflowRunResponse.__name__],
|
||||
)
|
||||
def get(self, app_model: App, workflow_run_id: str):
|
||||
"""Get a workflow task running detail.
|
||||
|
||||
@ -133,7 +244,7 @@ class WorkflowRunDetailApi(Resource):
|
||||
)
|
||||
if not workflow_run:
|
||||
raise NotFound("Workflow run not found.")
|
||||
return workflow_run
|
||||
return _serialize_workflow_run(workflow_run)
|
||||
|
||||
|
||||
@service_api_ns.route("/workflows/run")
|
||||
@ -299,7 +410,11 @@ class WorkflowAppLogApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
@service_api_ns.marshal_with(build_workflow_app_log_pagination_model(service_api_ns))
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Logs retrieved successfully",
|
||||
service_api_ns.models[WorkflowAppLogPaginationResponse.__name__],
|
||||
)
|
||||
def get(self, app_model: App):
|
||||
"""Get workflow app logs.
|
||||
|
||||
@ -327,4 +442,4 @@ class WorkflowAppLogApi(Resource):
|
||||
created_by_account=args.created_by_account,
|
||||
)
|
||||
|
||||
return workflow_app_log_pagination
|
||||
return _serialize_workflow_log_pagination(workflow_app_log_pagination)
|
||||
|
||||
@ -33,25 +33,25 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
|
||||
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
|
||||
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict[str, Any]:
|
||||
"""Marshal a single segment and enrich it with summary content."""
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
|
||||
segment_dict["summary"] = summary.summary_content if summary else None
|
||||
return segment_dict
|
||||
|
||||
|
||||
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
|
||||
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict[str, Any]]:
|
||||
"""Marshal multiple segments and enrich them with summary content (batch query)."""
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
summaries: dict = {}
|
||||
summaries: dict[str, str | None] = {}
|
||||
if segment_ids:
|
||||
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
|
||||
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
|
||||
|
||||
result = []
|
||||
result: list[dict[str, Any]] = []
|
||||
for segment in segments:
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
segment_dict["summary"] = summaries.get(segment.id)
|
||||
result.append(segment_dict)
|
||||
return result
|
||||
|
||||
@ -5,6 +5,7 @@ Web App Human Input Form APIs.
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
@ -58,10 +59,19 @@ def _to_timestamp(value: datetime) -> int:
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
class FormDefinitionPayload(TypedDict):
|
||||
form_content: Any
|
||||
inputs: Any
|
||||
resolved_default_values: dict[str, str]
|
||||
user_actions: Any
|
||||
expiration_time: int
|
||||
site: NotRequired[dict]
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
|
||||
"""Return the form payload (optionally with site) as a JSON response."""
|
||||
definition_payload = form.get_definition().model_dump()
|
||||
payload = {
|
||||
payload: FormDefinitionPayload = {
|
||||
"form_content": definition_payload["rendered_content"],
|
||||
"inputs": definition_payload["inputs"],
|
||||
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
|
||||
@ -92,7 +102,7 @@ class HumanInputFormApi(Resource):
|
||||
_FORM_ACCESS_RATE_LIMITER.increment_rate_limit(ip_address)
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
# TODO(QuantumGhost): forbid submision for form tokens
|
||||
# TODO(QuantumGhost): forbid submission for form tokens
|
||||
# that are only for console.
|
||||
form = service.get_form_by_token(form_token)
|
||||
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
import logging
|
||||
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource
|
||||
from jwt import InvalidTokenError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
@ -20,7 +23,7 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.wraps import decode_jwt_token
|
||||
from libs.helper import EmailStr
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.passport import PassportService
|
||||
from libs.password import valid_password
|
||||
from libs.token import (
|
||||
@ -29,9 +32,11 @@ from libs.token import (
|
||||
)
|
||||
from services.account_service import AccountService
|
||||
from services.app_service import AppService
|
||||
from services.entities.auth_entities import LoginPayloadBase
|
||||
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoginPayload(LoginPayloadBase):
|
||||
@field_validator("password")
|
||||
@ -76,14 +81,18 @@ class LoginApi(Resource):
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
payload = LoginPayload.model_validate(web_ns.payload or {})
|
||||
normalized_email = payload.email.lower()
|
||||
|
||||
try:
|
||||
account = WebAppAuthService.authenticate(payload.email, payload.password)
|
||||
except services.errors.account.AccountLoginError:
|
||||
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError:
|
||||
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
|
||||
raise AuthenticationFailedError()
|
||||
except services.errors.account.AccountNotFoundError:
|
||||
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND)
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
token = WebAppAuthService.login(account=account)
|
||||
@ -212,21 +221,30 @@ class EmailCodeLoginApi(Resource):
|
||||
|
||||
token_data = WebAppAuthService.get_email_code_login_data(payload.token)
|
||||
if token_data is None:
|
||||
_log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
|
||||
raise InvalidTokenError()
|
||||
|
||||
token_email = token_data.get("email")
|
||||
if not isinstance(token_email, str):
|
||||
_log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
|
||||
raise InvalidEmailError()
|
||||
normalized_token_email = token_email.lower()
|
||||
if normalized_token_email != user_email:
|
||||
_log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != payload.code:
|
||||
_log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
|
||||
raise EmailCodeError()
|
||||
|
||||
WebAppAuthService.revoke_email_code_login_token(payload.token)
|
||||
account = WebAppAuthService.get_user_through_email(token_email)
|
||||
try:
|
||||
account = WebAppAuthService.get_user_through_email(token_email)
|
||||
except Unauthorized as exc:
|
||||
_log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
|
||||
raise AccountBannedError() from exc
|
||||
if not account:
|
||||
_log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND)
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
token = WebAppAuthService.login(account=account)
|
||||
@ -234,3 +252,12 @@ class EmailCodeLoginApi(Resource):
|
||||
response = make_response({"result": "success", "data": {"access_token": token}})
|
||||
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
|
||||
return response
|
||||
|
||||
|
||||
def _log_web_login_failure(*, email: str, reason: LoginFailureReason) -> None:
|
||||
logger.warning(
|
||||
"Web login failed: email=%s reason=%s ip_address=%s",
|
||||
email,
|
||||
reason,
|
||||
extract_remote_ip(request),
|
||||
)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource
|
||||
@ -103,21 +104,23 @@ class PassportResource(Resource):
|
||||
return response
|
||||
|
||||
|
||||
def decode_enterprise_webapp_user_id(jwt_token: str | None):
|
||||
def decode_enterprise_webapp_user_id(jwt_token: str | None) -> dict[str, Any] | None:
|
||||
"""
|
||||
Decode the enterprise user session from the Authorization header.
|
||||
"""
|
||||
if not jwt_token:
|
||||
return None
|
||||
|
||||
decoded = PassportService().verify(jwt_token)
|
||||
decoded: dict[str, Any] = PassportService().verify(jwt_token)
|
||||
source = decoded.get("token_source")
|
||||
if not source or source != "webapp_login_token":
|
||||
raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
|
||||
return decoded
|
||||
|
||||
|
||||
def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType):
|
||||
def exchange_token_for_existing_web_user(
|
||||
app_code: str, enterprise_user_decoded: dict[str, Any], auth_type: WebAppAuthType
|
||||
):
|
||||
"""
|
||||
Exchange a token for an existing web user session.
|
||||
"""
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
from flask_restx import fields, marshal, marshal_with
|
||||
from sqlalchemy import select
|
||||
@ -113,12 +113,12 @@ class AppSiteInfo:
|
||||
}
|
||||
|
||||
|
||||
def serialize_site(site: Site) -> dict:
|
||||
def serialize_site(site: Site) -> dict[str, Any]:
|
||||
"""Serialize Site model using the same schema as AppSiteApi."""
|
||||
return cast(dict, marshal(site, AppSiteApi.site_fields))
|
||||
return cast(dict[str, Any], marshal(site, AppSiteApi.site_fields))
|
||||
|
||||
|
||||
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict:
|
||||
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict[str, Any]:
|
||||
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
|
||||
app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
|
||||
return cast(dict, marshal(app_site_info, AppSiteApi.app_fields))
|
||||
return cast(dict[str, Any], marshal(app_site_info, AppSiteApi.app_fields))
|
||||
|
||||
@ -1,399 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from typing import Any, cast
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentResult, ExecutionContext
|
||||
from core.agent.patterns.strategy_factory import StrategyFactory
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from graphon.file import file_manager
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentAppRunner(BaseAgentRunner):
|
||||
|
||||
@property
|
||||
def model_features(self) -> list[ModelFeature]:
|
||||
llm_model = cast(LargeLanguageModel, self.model_instance.model_type_instance)
|
||||
model_schema = llm_model.get_model_schema(self.model_instance.model_name, self.model_instance.credentials)
|
||||
if not model_schema:
|
||||
return []
|
||||
return list(model_schema.features or [])
|
||||
|
||||
def build_execution_context(self) -> ExecutionContext:
|
||||
return ExecutionContext(
|
||||
user_id=self.user_id,
|
||||
app_id=self.application_generate_entity.app_config.app_id,
|
||||
conversation_id=self.conversation.id if self.conversation else None,
|
||||
message_id=self.message.id if self.message else None,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
def _create_tool_invoke_hook(self, message: Message):
|
||||
"""
|
||||
Create a tool invoke hook that uses ToolEngine.agent_invoke.
|
||||
This hook handles file creation and returns proper meta information.
|
||||
"""
|
||||
# Get trace manager from app generate entity
|
||||
trace_manager = self.application_generate_entity.trace_manager
|
||||
|
||||
def tool_invoke_hook(
|
||||
tool: Tool, tool_args: dict[str, Any], tool_name: str
|
||||
) -> tuple[str, list[str], ToolInvokeMeta]:
|
||||
"""Hook that uses agent_invoke for proper file and meta handling."""
|
||||
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
|
||||
tool=tool,
|
||||
tool_parameters=tool_args,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
message=message,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
agent_tool_callback=self.agent_callback,
|
||||
trace_manager=trace_manager,
|
||||
app_id=self.application_generate_entity.app_config.app_id,
|
||||
message_id=message.id,
|
||||
conversation_id=self.conversation.id,
|
||||
)
|
||||
|
||||
# Publish files and track IDs
|
||||
for message_file_id in message_files:
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
self._current_message_file_ids.append(message_file_id)
|
||||
|
||||
return tool_invoke_response, message_files, tool_invoke_meta
|
||||
|
||||
return tool_invoke_hook
|
||||
|
||||
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Run Agent application
|
||||
"""
|
||||
self.query = query
|
||||
app_generate_entity = self.application_generate_entity
|
||||
|
||||
app_config = self.app_config
|
||||
assert app_config is not None, "app_config is required"
|
||||
assert app_config.agent is not None, "app_config.agent is required"
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, _ = self._init_prompt_tools()
|
||||
|
||||
assert app_config.agent
|
||||
|
||||
# Create tool invoke hook for agent_invoke
|
||||
tool_invoke_hook = self._create_tool_invoke_hook(message)
|
||||
|
||||
# Get instruction for ReAct strategy
|
||||
instruction = self.app_config.prompt_template.simple_prompt_template or ""
|
||||
|
||||
# Use factory to create appropriate strategy
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=self.model_features,
|
||||
model_instance=self.model_instance,
|
||||
tools=list(tool_instances.values()),
|
||||
files=list(self.files),
|
||||
max_iterations=app_config.agent.max_iteration,
|
||||
context=self.build_execution_context(),
|
||||
agent_strategy=self.config.strategy,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
# Initialize state variables
|
||||
current_agent_thought_id: str | None = None
|
||||
has_published_thought = False
|
||||
current_tool_name: str | None = None
|
||||
self._current_message_file_ids: list[str] = []
|
||||
|
||||
# organize prompt messages
|
||||
prompt_messages = self._organize_prompt_messages()
|
||||
|
||||
# Run strategy
|
||||
generator = strategy.run(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_generate_entity.model_conf.parameters,
|
||||
stop=app_generate_entity.model_conf.stop,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Consume generator and collect result
|
||||
result: AgentResult | None = None
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
output = next(generator)
|
||||
except StopIteration as e:
|
||||
# Generator finished, get the return value
|
||||
result = e.value
|
||||
break
|
||||
|
||||
if isinstance(output, LLMResultChunk):
|
||||
# Handle LLM chunk
|
||||
if current_agent_thought_id and not has_published_thought:
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
has_published_thought = True
|
||||
|
||||
yield output
|
||||
|
||||
elif isinstance(output, AgentLog):
|
||||
# Handle Agent Log using log_type for type-safe dispatch
|
||||
if output.status == AgentLog.LogStatus.START:
|
||||
if output.log_type == AgentLog.LogType.ROUND:
|
||||
# Start of a new round
|
||||
message_file_ids: list[str] = []
|
||||
current_agent_thought_id = self.create_agent_thought(
|
||||
message_id=message.id,
|
||||
message="",
|
||||
tool_name="",
|
||||
tool_input="",
|
||||
messages_ids=message_file_ids,
|
||||
)
|
||||
has_published_thought = False
|
||||
|
||||
elif output.log_type == AgentLog.LogType.TOOL_CALL:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
# Tool call start - extract data from structured fields
|
||||
current_tool_name = output.data.get("tool_name", "")
|
||||
tool_input = output.data.get("tool_args", {})
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=current_tool_name,
|
||||
tool_input=tool_input,
|
||||
thought=None,
|
||||
observation=None,
|
||||
tool_invoke_meta=None,
|
||||
answer=None,
|
||||
messages_ids=[],
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
elif output.status == AgentLog.LogStatus.SUCCESS:
|
||||
if output.log_type == AgentLog.LogType.THOUGHT:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
thought_text = output.data.get("thought")
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=thought_text,
|
||||
observation=None,
|
||||
tool_invoke_meta=None,
|
||||
answer=None,
|
||||
messages_ids=[],
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
elif output.log_type == AgentLog.LogType.TOOL_CALL:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
# Tool call finished
|
||||
tool_output = output.data.get("output")
|
||||
# Get meta from strategy output (now properly populated)
|
||||
tool_meta = output.data.get("meta")
|
||||
|
||||
# Wrap tool_meta with tool_name as key (required by agent_service)
|
||||
if tool_meta and current_tool_name:
|
||||
tool_meta = {current_tool_name: tool_meta}
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=None,
|
||||
observation=tool_output,
|
||||
tool_invoke_meta=tool_meta,
|
||||
answer=None,
|
||||
messages_ids=self._current_message_file_ids,
|
||||
)
|
||||
# Clear message file ids after saving
|
||||
self._current_message_file_ids = []
|
||||
current_tool_name = None
|
||||
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
elif output.log_type == AgentLog.LogType.ROUND:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
# Round finished - save LLM usage and answer
|
||||
llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE)
|
||||
llm_result = output.data.get("llm_result")
|
||||
final_answer = output.data.get("final_answer")
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=llm_result,
|
||||
observation=None,
|
||||
tool_invoke_meta=None,
|
||||
answer=final_answer,
|
||||
messages_ids=[],
|
||||
llm_usage=llm_usage,
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# Re-raise any other exceptions
|
||||
raise
|
||||
|
||||
# Process final result
|
||||
if isinstance(result, AgentResult):
|
||||
final_answer = result.text
|
||||
usage = result.usage or LLMUsage.empty_usage()
|
||||
|
||||
# Publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=self.model_instance.model_name,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=final_answer),
|
||||
usage=usage,
|
||||
system_fingerprint="",
|
||||
)
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Initialize system message
|
||||
"""
|
||||
if not prompt_template:
|
||||
return prompt_messages or []
|
||||
|
||||
prompt_messages = prompt_messages or []
|
||||
|
||||
if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
|
||||
prompt_messages[0] = SystemPromptMessage(content=prompt_template)
|
||||
return prompt_messages
|
||||
|
||||
if not prompt_messages:
|
||||
return [SystemPromptMessage(content=prompt_template)]
|
||||
|
||||
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
|
||||
return prompt_messages
|
||||
|
||||
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
if self.files:
|
||||
# get image detail config
|
||||
image_detail_config = (
|
||||
self.application_generate_entity.file_upload_config.image_config.detail
|
||||
if (
|
||||
self.application_generate_entity.file_upload_config
|
||||
and self.application_generate_entity.file_upload_config.image_config
|
||||
)
|
||||
else None
|
||||
)
|
||||
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
for file in self.files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=query))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
As for now, gpt supports both fc and vision at the first iteration.
|
||||
We need to remove the image messages from the prompt messages at the first iteration.
|
||||
"""
|
||||
prompt_messages = deepcopy(prompt_messages)
|
||||
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, UserPromptMessage):
|
||||
if isinstance(prompt_message.content, list):
|
||||
prompt_message.content = "\n".join(
|
||||
[
|
||||
content.data
|
||||
if content.type == PromptMessageContentType.TEXT
|
||||
else "[image]"
|
||||
if content.type == PromptMessageContentType.IMAGE
|
||||
else "[file]"
|
||||
for content in prompt_message.content
|
||||
]
|
||||
)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _organize_prompt_messages(self):
|
||||
# For ReAct strategy, use the agent prompt template
|
||||
if self.config.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT and self.config.prompt:
|
||||
prompt_template = self.config.prompt.first_prompt
|
||||
else:
|
||||
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
|
||||
|
||||
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
|
||||
query_prompt_messages = self._organize_user_query(self.query or "", [])
|
||||
|
||||
self.history_prompt_messages = AgentHistoryPromptTransform(
|
||||
model_config=self.model_config,
|
||||
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
|
||||
history_messages=self.history_prompt_messages,
|
||||
memory=self.memory,
|
||||
).get_prompt()
|
||||
|
||||
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
|
||||
if len(self._current_thoughts) != 0:
|
||||
# clear messages after the first iteration
|
||||
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
|
||||
return prompt_messages
|
||||
@ -1,5 +1,3 @@
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
|
||||
@ -94,79 +92,3 @@ class AgentInvokeMessage(ToolInvokeMessage):
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ExecutionContext(BaseModel):
|
||||
"""Execution context containing trace and audit information.
|
||||
|
||||
Carries IDs and metadata needed for tracing, auditing, and correlation
|
||||
but not part of the core business logic.
|
||||
"""
|
||||
|
||||
user_id: str | None = None
|
||||
app_id: str | None = None
|
||||
conversation_id: str | None = None
|
||||
message_id: str | None = None
|
||||
tenant_id: str | None = None
|
||||
|
||||
@classmethod
|
||||
def create_minimal(cls, user_id: str | None = None) -> "ExecutionContext":
|
||||
return cls(user_id=user_id)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"user_id": self.user_id,
|
||||
"app_id": self.app_id,
|
||||
"conversation_id": self.conversation_id,
|
||||
"message_id": self.message_id,
|
||||
"tenant_id": self.tenant_id,
|
||||
}
|
||||
|
||||
def with_updates(self, **kwargs) -> "ExecutionContext":
|
||||
data = self.to_dict()
|
||||
data.update(kwargs)
|
||||
return ExecutionContext(**{k: v for k, v in data.items() if k in ExecutionContext.model_fields})
|
||||
|
||||
|
||||
class AgentLog(BaseModel):
|
||||
"""Structured log entry for agent execution tracing."""
|
||||
|
||||
class LogType(StrEnum):
|
||||
ROUND = "round"
|
||||
THOUGHT = "thought"
|
||||
TOOL_CALL = "tool_call"
|
||||
|
||||
class LogMetadata(StrEnum):
|
||||
STARTED_AT = "started_at"
|
||||
FINISHED_AT = "finished_at"
|
||||
ELAPSED_TIME = "elapsed_time"
|
||||
TOTAL_PRICE = "total_price"
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
PROVIDER = "provider"
|
||||
CURRENCY = "currency"
|
||||
LLM_USAGE = "llm_usage"
|
||||
ICON = "icon"
|
||||
ICON_DARK = "icon_dark"
|
||||
|
||||
class LogStatus(StrEnum):
|
||||
START = "start"
|
||||
ERROR = "error"
|
||||
SUCCESS = "success"
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
label: str = Field(...)
|
||||
log_type: LogType = Field(...)
|
||||
parent_id: str | None = Field(default=None)
|
||||
error: str | None = Field(default=None)
|
||||
status: LogStatus = Field(...)
|
||||
data: Mapping[str, Any] = Field(...)
|
||||
metadata: Mapping[LogMetadata, Any] = Field(default={})
|
||||
|
||||
|
||||
class AgentResult(BaseModel):
|
||||
"""Agent execution result."""
|
||||
|
||||
text: str = Field(default="")
|
||||
files: list[Any] = Field(default_factory=list)
|
||||
usage: Any | None = Field(default=None)
|
||||
finish_reason: str | None = Field(default=None)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
from typing import Any, Union
|
||||
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResultChunk
|
||||
|
||||
@ -11,7 +11,7 @@ from core.agent.entities import AgentScratchpadUnit
|
||||
class CotAgentOutputParser:
|
||||
@classmethod
|
||||
def handle_react_stream_output(
|
||||
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
|
||||
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict[str, Any]
|
||||
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
|
||||
def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]:
|
||||
action_name = None
|
||||
|
||||
@ -1,19 +0,0 @@
|
||||
"""Agent patterns module.
|
||||
|
||||
This module provides different strategies for agent execution:
|
||||
- FunctionCallStrategy: Uses native function/tool calling
|
||||
- ReActStrategy: Uses ReAct (Reasoning + Acting) approach
|
||||
- StrategyFactory: Factory for creating strategies based on model features
|
||||
"""
|
||||
|
||||
from .base import AgentPattern
|
||||
from .function_call import FunctionCallStrategy
|
||||
from .react import ReActStrategy
|
||||
from .strategy_factory import StrategyFactory
|
||||
|
||||
__all__ = [
|
||||
"AgentPattern",
|
||||
"FunctionCallStrategy",
|
||||
"ReActStrategy",
|
||||
"StrategyFactory",
|
||||
]
|
||||
@ -1,506 +0,0 @@
|
||||
"""Base class for agent strategies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult, ExecutionContext
|
||||
from core.model_manager import ModelInstance
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMeta
|
||||
from graphon.file import File
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
# Type alias for tool invoke hook
|
||||
# Returns: (response_content, message_file_ids, tool_invoke_meta)
|
||||
ToolInvokeHook = Callable[["Tool", dict[str, Any], str], tuple[str, list[str], ToolInvokeMeta]]
|
||||
|
||||
|
||||
class AgentPattern(ABC):
|
||||
"""Base class for agent execution strategies."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_instance: ModelInstance,
|
||||
tools: list[Tool],
|
||||
context: ExecutionContext,
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
files: list[File] = [],
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
):
|
||||
"""Initialize the agent strategy."""
|
||||
self.model_instance = model_instance
|
||||
self.tools = tools
|
||||
self.context = context
|
||||
self.max_iterations = min(max_iterations, 99) # Cap at 99 iterations
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
self.files: list[File] = files
|
||||
self.tool_invoke_hook = tool_invoke_hook
|
||||
|
||||
@abstractmethod
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the agent strategy."""
|
||||
pass
|
||||
|
||||
def _accumulate_usage(self, total_usage: dict[str, Any], delta_usage: LLMUsage) -> None:
|
||||
"""Accumulate LLM usage statistics."""
|
||||
if not total_usage.get("usage"):
|
||||
# Create a copy to avoid modifying the original
|
||||
total_usage["usage"] = LLMUsage(
|
||||
prompt_tokens=delta_usage.prompt_tokens,
|
||||
prompt_unit_price=delta_usage.prompt_unit_price,
|
||||
prompt_price_unit=delta_usage.prompt_price_unit,
|
||||
prompt_price=delta_usage.prompt_price,
|
||||
completion_tokens=delta_usage.completion_tokens,
|
||||
completion_unit_price=delta_usage.completion_unit_price,
|
||||
completion_price_unit=delta_usage.completion_price_unit,
|
||||
completion_price=delta_usage.completion_price,
|
||||
total_tokens=delta_usage.total_tokens,
|
||||
total_price=delta_usage.total_price,
|
||||
currency=delta_usage.currency,
|
||||
latency=delta_usage.latency,
|
||||
)
|
||||
else:
|
||||
current: LLMUsage = total_usage["usage"]
|
||||
current.prompt_tokens += delta_usage.prompt_tokens
|
||||
current.completion_tokens += delta_usage.completion_tokens
|
||||
current.total_tokens += delta_usage.total_tokens
|
||||
current.prompt_price += delta_usage.prompt_price
|
||||
current.completion_price += delta_usage.completion_price
|
||||
current.total_price += delta_usage.total_price
|
||||
|
||||
def _extract_content(self, content: Any) -> str:
|
||||
"""Extract text content from message content."""
|
||||
if isinstance(content, list):
|
||||
# Content items are PromptMessageContentUnionTypes
|
||||
text_parts = []
|
||||
for c in content:
|
||||
# Check if it's a TextPromptMessageContent (which has data attribute)
|
||||
if isinstance(c, TextPromptMessageContent):
|
||||
text_parts.append(c.data)
|
||||
return "".join(text_parts)
|
||||
return str(content)
|
||||
|
||||
def _has_tool_calls(self, chunk: LLMResultChunk) -> bool:
|
||||
"""Check if chunk contains tool calls."""
|
||||
# LLMResultChunk always has delta attribute
|
||||
return bool(chunk.delta.message and chunk.delta.message.tool_calls)
|
||||
|
||||
def _has_tool_calls_result(self, result: LLMResult) -> bool:
|
||||
"""Check if result contains tool calls (non-streaming)."""
|
||||
# LLMResult always has message attribute
|
||||
return bool(result.message and result.message.tool_calls)
|
||||
|
||||
def _extract_tool_calls(self, chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""Extract tool calls from streaming chunk."""
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
if chunk.delta.message and chunk.delta.message.tool_calls:
|
||||
for tool_call in chunk.delta.message.tool_calls:
|
||||
if tool_call.function:
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
|
||||
return tool_calls
|
||||
|
||||
def _extract_tool_calls_result(self, result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""Extract tool calls from non-streaming result."""
|
||||
tool_calls = []
|
||||
if result.message and result.message.tool_calls:
|
||||
for tool_call in result.message.tool_calls:
|
||||
if tool_call.function:
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
|
||||
return tool_calls
|
||||
|
||||
def _extract_text_from_message(self, message: PromptMessage) -> str:
|
||||
"""Extract text content from a prompt message."""
|
||||
# PromptMessage always has content attribute
|
||||
content = message.content
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
# Extract text from content list
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
text_parts.append(item.data)
|
||||
return " ".join(text_parts)
|
||||
return ""
|
||||
|
||||
def _get_tool_metadata(self, tool_instance: Tool) -> dict[AgentLog.LogMetadata, Any]:
|
||||
"""Get metadata for a tool including provider and icon info."""
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
metadata: dict[AgentLog.LogMetadata, Any] = {}
|
||||
if tool_instance.entity and tool_instance.entity.identity:
|
||||
identity = tool_instance.entity.identity
|
||||
if identity.provider:
|
||||
metadata[AgentLog.LogMetadata.PROVIDER] = identity.provider
|
||||
|
||||
# Get icon using ToolManager for proper URL generation
|
||||
tenant_id = self.context.tenant_id
|
||||
if tenant_id and identity.provider:
|
||||
try:
|
||||
provider_type = tool_instance.tool_provider_type()
|
||||
icon = ToolManager.get_tool_icon(tenant_id, provider_type, identity.provider)
|
||||
if isinstance(icon, str):
|
||||
metadata[AgentLog.LogMetadata.ICON] = icon
|
||||
elif isinstance(icon, dict):
|
||||
# Handle icon dict with background/content or light/dark variants
|
||||
metadata[AgentLog.LogMetadata.ICON] = icon
|
||||
except Exception:
|
||||
# Fallback to identity.icon if ToolManager fails
|
||||
if identity.icon:
|
||||
metadata[AgentLog.LogMetadata.ICON] = identity.icon
|
||||
elif identity.icon:
|
||||
metadata[AgentLog.LogMetadata.ICON] = identity.icon
|
||||
return metadata
|
||||
|
||||
def _create_log(
|
||||
self,
|
||||
label: str,
|
||||
log_type: AgentLog.LogType,
|
||||
status: AgentLog.LogStatus,
|
||||
data: dict[str, Any] | None = None,
|
||||
parent_id: str | None = None,
|
||||
extra_metadata: dict[AgentLog.LogMetadata, Any] | None = None,
|
||||
) -> AgentLog:
|
||||
"""Create a new AgentLog with standard metadata."""
|
||||
metadata: dict[AgentLog.LogMetadata, Any] = {
|
||||
AgentLog.LogMetadata.STARTED_AT: time.perf_counter(),
|
||||
}
|
||||
if extra_metadata:
|
||||
metadata.update(extra_metadata)
|
||||
|
||||
return AgentLog(
|
||||
label=label,
|
||||
log_type=log_type,
|
||||
status=status,
|
||||
data=data or {},
|
||||
parent_id=parent_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _finish_log(
|
||||
self,
|
||||
log: AgentLog,
|
||||
data: dict[str, Any] | None = None,
|
||||
usage: LLMUsage | None = None,
|
||||
) -> AgentLog:
|
||||
"""Finish an AgentLog by updating its status and metadata."""
|
||||
log.status = AgentLog.LogStatus.SUCCESS
|
||||
|
||||
if data is not None:
|
||||
log.data = data
|
||||
|
||||
# Calculate elapsed time
|
||||
started_at = log.metadata.get(AgentLog.LogMetadata.STARTED_AT, time.perf_counter())
|
||||
finished_at = time.perf_counter()
|
||||
|
||||
# Update metadata
|
||||
log.metadata = {
|
||||
**log.metadata,
|
||||
AgentLog.LogMetadata.FINISHED_AT: finished_at,
|
||||
# Calculate elapsed time in seconds
|
||||
AgentLog.LogMetadata.ELAPSED_TIME: round(finished_at - started_at, 4),
|
||||
}
|
||||
|
||||
# Add usage information if provided
|
||||
if usage:
|
||||
log.metadata.update(
|
||||
{
|
||||
AgentLog.LogMetadata.TOTAL_PRICE: usage.total_price,
|
||||
AgentLog.LogMetadata.CURRENCY: usage.currency,
|
||||
AgentLog.LogMetadata.TOTAL_TOKENS: usage.total_tokens,
|
||||
AgentLog.LogMetadata.LLM_USAGE: usage,
|
||||
}
|
||||
)
|
||||
|
||||
return log
|
||||
|
||||
def _replace_file_references(self, tool_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Replace file references in tool arguments with actual File objects.
|
||||
|
||||
Args:
|
||||
tool_args: Dictionary of tool arguments
|
||||
|
||||
Returns:
|
||||
Updated tool arguments with file references replaced
|
||||
"""
|
||||
# Process each argument in the dictionary
|
||||
processed_args: dict[str, Any] = {}
|
||||
for key, value in tool_args.items():
|
||||
processed_args[key] = self._process_file_reference(value)
|
||||
return processed_args
|
||||
|
||||
def _process_file_reference(self, data: Any) -> Any:
|
||||
"""
|
||||
Recursively process data to replace file references.
|
||||
Supports both single file [File: file_id] and multiple files [Files: file_id1, file_id2, ...].
|
||||
|
||||
Args:
|
||||
data: The data to process (can be dict, list, str, or other types)
|
||||
|
||||
Returns:
|
||||
Processed data with file references replaced
|
||||
"""
|
||||
single_file_pattern = re.compile(r"^\[File:\s*([^\]]+)\]$")
|
||||
multiple_files_pattern = re.compile(r"^\[Files:\s*([^\]]+)\]$")
|
||||
|
||||
if isinstance(data, dict):
|
||||
# Process dictionary recursively
|
||||
return {key: self._process_file_reference(value) for key, value in data.items()}
|
||||
elif isinstance(data, list):
|
||||
# Process list recursively
|
||||
return [self._process_file_reference(item) for item in data]
|
||||
elif isinstance(data, str):
|
||||
# Check for single file pattern [File: file_id]
|
||||
single_match = single_file_pattern.match(data.strip())
|
||||
if single_match:
|
||||
file_id = single_match.group(1).strip()
|
||||
# Find the file in self.files
|
||||
for file in self.files:
|
||||
if file.id and str(file.id) == file_id:
|
||||
return file
|
||||
# If file not found, return original value
|
||||
return data
|
||||
|
||||
# Check for multiple files pattern [Files: file_id1, file_id2, ...]
|
||||
multiple_match = multiple_files_pattern.match(data.strip())
|
||||
if multiple_match:
|
||||
file_ids_str = multiple_match.group(1).strip()
|
||||
# Split by comma and strip whitespace
|
||||
file_ids = [fid.strip() for fid in file_ids_str.split(",")]
|
||||
|
||||
# Find all matching files
|
||||
matched_files: list[File] = []
|
||||
for file_id in file_ids:
|
||||
for file in self.files:
|
||||
if file.id and str(file.id) == file_id:
|
||||
matched_files.append(file)
|
||||
break
|
||||
|
||||
# Return list of files if any were found, otherwise return original
|
||||
return matched_files or data
|
||||
|
||||
return data
|
||||
else:
|
||||
# Return other types as-is
|
||||
return data
|
||||
|
||||
def _create_text_chunk(self, text: str, prompt_messages: list[PromptMessage]) -> LLMResultChunk:
|
||||
"""Create a text chunk for streaming."""
|
||||
return LLMResultChunk(
|
||||
model=self.model_instance.model_name,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=text),
|
||||
usage=None,
|
||||
),
|
||||
system_fingerprint="",
|
||||
)
|
||||
|
||||
def _invoke_tool(
|
||||
self,
|
||||
tool_instance: Tool,
|
||||
tool_args: dict[str, Any],
|
||||
tool_name: str,
|
||||
) -> tuple[str, list[File], ToolInvokeMeta | None]:
|
||||
"""
|
||||
Invoke a tool and collect its response.
|
||||
|
||||
Args:
|
||||
tool_instance: The tool instance to invoke
|
||||
tool_args: Tool arguments
|
||||
tool_name: Name of the tool
|
||||
|
||||
Returns:
|
||||
Tuple of (response_content, tool_files, tool_invoke_meta)
|
||||
"""
|
||||
# Process tool_args to replace file references with actual File objects
|
||||
tool_args = self._replace_file_references(tool_args)
|
||||
|
||||
# If a tool invoke hook is set, use it instead of generic_invoke
|
||||
if self.tool_invoke_hook:
|
||||
response_content, _, tool_invoke_meta = self.tool_invoke_hook(tool_instance, tool_args, tool_name)
|
||||
# Note: message_file_ids are stored in DB, we don't convert them to File objects here
|
||||
# The caller (AgentAppRunner) handles file publishing
|
||||
return response_content, [], tool_invoke_meta
|
||||
|
||||
# Default: use generic_invoke for workflow scenarios
|
||||
# Import here to avoid circular import
|
||||
from core.tools.tool_engine import DifyWorkflowCallbackHandler, ToolEngine
|
||||
|
||||
tool_response = ToolEngine.generic_invoke(
|
||||
tool=tool_instance,
|
||||
tool_parameters=tool_args,
|
||||
user_id=self.context.user_id or "",
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler(),
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
app_id=self.context.app_id,
|
||||
conversation_id=self.context.conversation_id,
|
||||
message_id=self.context.message_id,
|
||||
)
|
||||
|
||||
# Collect response and files
|
||||
response_content = ""
|
||||
tool_files: list[File] = []
|
||||
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
|
||||
response_content += response.message.text
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
# Handle link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Link: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
# Handle image URL messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Image: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK:
|
||||
# Handle image link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Image: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BINARY_LINK:
|
||||
# Handle binary file link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
filename = response.meta.get("filename", "file") if response.meta else "file"
|
||||
response_content += f"[File: {filename} - {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.JSON:
|
||||
# Handle JSON messages
|
||||
if isinstance(response.message, ToolInvokeMessage.JsonMessage):
|
||||
response_content += json.dumps(response.message.json_object, ensure_ascii=False, indent=2)
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# Handle blob messages - convert to text representation
|
||||
if isinstance(response.message, ToolInvokeMessage.BlobMessage):
|
||||
mime_type = (
|
||||
response.meta.get("mime_type", "application/octet-stream")
|
||||
if response.meta
|
||||
else "application/octet-stream"
|
||||
)
|
||||
size = len(response.message.blob)
|
||||
response_content += f"[Binary data: {mime_type}, size: {size} bytes]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
# Handle variable messages
|
||||
if isinstance(response.message, ToolInvokeMessage.VariableMessage):
|
||||
var_name = response.message.variable_name
|
||||
var_value = response.message.variable_value
|
||||
if isinstance(var_value, str):
|
||||
response_content += var_value
|
||||
else:
|
||||
response_content += f"[Variable {var_name}: {json.dumps(var_value, ensure_ascii=False)}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
|
||||
# Handle blob chunk messages - these are parts of a larger blob
|
||||
if isinstance(response.message, ToolInvokeMessage.BlobChunkMessage):
|
||||
response_content += f"[Blob chunk {response.message.sequence}: {len(response.message.blob)} bytes]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES:
|
||||
# Handle retriever resources messages
|
||||
if isinstance(response.message, ToolInvokeMessage.RetrieverResourceMessage):
|
||||
response_content += response.message.context
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.FILE:
|
||||
# Extract file from meta
|
||||
if response.meta and "file" in response.meta:
|
||||
file = response.meta["file"]
|
||||
if isinstance(file, File):
|
||||
# Check if file is for model or tool output
|
||||
if response.meta.get("target") == "self":
|
||||
# File is for model - add to files for next prompt
|
||||
self.files.append(file)
|
||||
response_content += f"File '{file.filename}' has been loaded into your context."
|
||||
else:
|
||||
# File is tool output
|
||||
tool_files.append(file)
|
||||
|
||||
return response_content, tool_files, None
|
||||
|
||||
def _validate_tool_args(self, tool_instance: Tool, tool_args: dict[str, Any]) -> str | None:
|
||||
"""Validate tool arguments against the tool's required parameters.
|
||||
|
||||
Checks that all required LLM-facing parameters are present and non-empty
|
||||
before actual execution, preventing wasted tool invocations when the model
|
||||
generates calls with missing arguments (e.g. empty ``{}``).
|
||||
|
||||
Returns:
|
||||
Error message if validation fails, None if all required parameters are satisfied.
|
||||
"""
|
||||
prompt_tool = tool_instance.to_prompt_message_tool()
|
||||
required_params: list[str] = prompt_tool.parameters.get("required", [])
|
||||
|
||||
if not required_params:
|
||||
return None
|
||||
|
||||
missing = [
|
||||
p
|
||||
for p in required_params
|
||||
if p not in tool_args
|
||||
or tool_args[p] is None
|
||||
or (isinstance(tool_args[p], str) and not tool_args[p].strip())
|
||||
]
|
||||
|
||||
if not missing:
|
||||
return None
|
||||
|
||||
return (
|
||||
f"Missing required parameter(s): {', '.join(missing)}. "
|
||||
f"Please provide all required parameters before calling this tool."
|
||||
)
|
||||
|
||||
def _find_tool_by_name(self, tool_name: str) -> Tool | None:
|
||||
"""Find a tool instance by its name."""
|
||||
for tool in self.tools:
|
||||
if tool.entity.identity.name == tool_name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
def _convert_tools_to_prompt_format(self) -> list[PromptMessageTool]:
|
||||
"""Convert tools to prompt message format."""
|
||||
prompt_tools: list[PromptMessageTool] = []
|
||||
for tool in self.tools:
|
||||
prompt_tools.append(tool.to_prompt_message_tool())
|
||||
return prompt_tools
|
||||
|
||||
def _update_usage_with_empty(self, llm_usage: dict[str, Any]) -> None:
|
||||
"""Initialize usage tracking with empty usage if not set."""
|
||||
if "usage" not in llm_usage or llm_usage["usage"] is None:
|
||||
llm_usage["usage"] = LLMUsage.empty_usage()
|
||||
@ -1,358 +0,0 @@
|
||||
"""Function Call strategy implementation.
|
||||
|
||||
Implements the Function Call agent pattern where the LLM uses native tool-calling
|
||||
capability to invoke tools. Includes pre-execution parameter validation that
|
||||
intercepts invalid calls (e.g. empty arguments) before they reach tool backends,
|
||||
and avoids counting purely-invalid rounds against the iteration budget.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Union
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from graphon.file import File
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
ToolPromptMessage,
|
||||
)
|
||||
|
||||
from .base import AgentPattern
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FunctionCallStrategy(AgentPattern):
|
||||
"""Function Call strategy using model's native tool calling capability."""
|
||||
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the function call agent strategy."""
|
||||
# Convert tools to prompt format
|
||||
prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format()
|
||||
|
||||
# Initialize tracking
|
||||
iteration_step: int = 1
|
||||
max_iterations: int = self.max_iterations + 1
|
||||
function_call_state: bool = True
|
||||
total_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
|
||||
final_text: str = ""
|
||||
finish_reason: str | None = None
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
# Consecutive rounds where ALL tool calls failed parameter validation.
|
||||
# When this happens the round is "free" (iteration_step not incremented)
|
||||
# up to a safety cap to prevent infinite loops.
|
||||
consecutive_validation_failures: int = 0
|
||||
max_validation_retries: int = 3
|
||||
|
||||
while function_call_state and iteration_step <= max_iterations:
|
||||
function_call_state = False
|
||||
round_log = self._create_log(
|
||||
label=f"ROUND {iteration_step}",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
yield round_log
|
||||
# On last iteration, remove tools to force final answer
|
||||
current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model_name} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata={
|
||||
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
|
||||
},
|
||||
)
|
||||
yield model_log
|
||||
|
||||
# Track usage for this round only
|
||||
round_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
prompt_messages=messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=current_tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# Process response
|
||||
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log
|
||||
)
|
||||
messages.append(self._create_assistant_message(response_content, tool_calls))
|
||||
|
||||
# Accumulate to total usage
|
||||
round_usage_value = round_usage.get("usage")
|
||||
if round_usage_value:
|
||||
self._accumulate_usage(total_usage, round_usage_value)
|
||||
|
||||
# Update final text if no tool calls (this is likely the final answer)
|
||||
if not tool_calls:
|
||||
final_text = response_content
|
||||
|
||||
# Update finish reason
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
|
||||
# Process tool calls
|
||||
tool_outputs: dict[str, str] = {}
|
||||
all_validation_errors: bool = True
|
||||
if tool_calls:
|
||||
function_call_state = True
|
||||
# Execute tools (with pre-execution parameter validation)
|
||||
for tool_call_id, tool_name, tool_args in tool_calls:
|
||||
tool_response, tool_files, _, is_validation_error = yield from self._handle_tool_call(
|
||||
tool_name, tool_args, tool_call_id, messages, round_log
|
||||
)
|
||||
tool_outputs[tool_name] = tool_response
|
||||
output_files.extend(tool_files)
|
||||
if not is_validation_error:
|
||||
all_validation_errors = False
|
||||
else:
|
||||
all_validation_errors = False
|
||||
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
"llm_result": response_content,
|
||||
"tool_calls": [
|
||||
{"name": tc[1], "args": tc[2], "output": tool_outputs.get(tc[1], "")} for tc in tool_calls
|
||||
]
|
||||
if tool_calls
|
||||
else [],
|
||||
"final_answer": final_text if not function_call_state else None,
|
||||
},
|
||||
usage=round_usage.get("usage"),
|
||||
)
|
||||
|
||||
# Skip iteration counter when every tool call in this round failed validation,
|
||||
# giving the model a free retry — but cap retries to prevent infinite loops.
|
||||
if tool_calls and all_validation_errors:
|
||||
consecutive_validation_failures += 1
|
||||
if consecutive_validation_failures >= max_validation_retries:
|
||||
logger.warning(
|
||||
"Agent hit %d consecutive validation-only rounds, forcing iteration increment",
|
||||
consecutive_validation_failures,
|
||||
)
|
||||
iteration_step += 1
|
||||
consecutive_validation_failures = 0
|
||||
else:
|
||||
logger.info(
|
||||
"All tool calls failed validation (attempt %d/%d), not counting iteration",
|
||||
consecutive_validation_failures,
|
||||
max_validation_retries,
|
||||
)
|
||||
else:
|
||||
consecutive_validation_failures = 0
|
||||
iteration_step += 1
|
||||
|
||||
# Return final result
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
return AgentResult(
|
||||
text=final_text,
|
||||
files=output_files,
|
||||
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
def _handle_chunks(
|
||||
self,
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
llm_usage: dict[str, LLMUsage | None],
|
||||
start_log: AgentLog,
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
tuple[list[tuple[str, str, dict[str, Any]]], str, str | None],
|
||||
]:
|
||||
"""Handle LLM response chunks and extract tool calls and content.
|
||||
|
||||
Returns a tuple of (tool_calls, response_content, finish_reason).
|
||||
"""
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
response_content: str = ""
|
||||
finish_reason: str | None = None
|
||||
if not isinstance(chunks, LLMResult):
|
||||
# Streaming response
|
||||
for chunk in chunks:
|
||||
# Extract tool calls
|
||||
if self._has_tool_calls(chunk):
|
||||
tool_calls.extend(self._extract_tool_calls(chunk))
|
||||
|
||||
# Extract content
|
||||
if chunk.delta.message and chunk.delta.message.content:
|
||||
response_content += self._extract_content(chunk.delta.message.content)
|
||||
|
||||
# Track usage
|
||||
if chunk.delta.usage:
|
||||
self._accumulate_usage(llm_usage, chunk.delta.usage)
|
||||
|
||||
# Capture finish reason
|
||||
if chunk.delta.finish_reason:
|
||||
finish_reason = chunk.delta.finish_reason
|
||||
|
||||
yield chunk
|
||||
else:
|
||||
# Non-streaming response
|
||||
result: LLMResult = chunks
|
||||
|
||||
if self._has_tool_calls_result(result):
|
||||
tool_calls.extend(self._extract_tool_calls_result(result))
|
||||
|
||||
if result.message and result.message.content:
|
||||
response_content += self._extract_content(result.message.content)
|
||||
|
||||
if result.usage:
|
||||
self._accumulate_usage(llm_usage, result.usage)
|
||||
|
||||
# Convert to streaming format
|
||||
yield LLMResultChunk(
|
||||
model=result.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
|
||||
)
|
||||
yield self._finish_log(
|
||||
start_log,
|
||||
data={
|
||||
"result": response_content,
|
||||
},
|
||||
usage=llm_usage.get("usage"),
|
||||
)
|
||||
return tool_calls, response_content, finish_reason
|
||||
|
||||
def _create_assistant_message(
|
||||
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
|
||||
) -> AssistantPromptMessage:
|
||||
"""Create assistant message with tool calls."""
|
||||
if tool_calls is None:
|
||||
return AssistantPromptMessage(content=content)
|
||||
return AssistantPromptMessage(
|
||||
content=content or "",
|
||||
tool_calls=[
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=tc[0],
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tc[1], arguments=json.dumps(tc[2])),
|
||||
)
|
||||
for tc in tool_calls
|
||||
],
|
||||
)
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, Any],
|
||||
tool_call_id: str,
|
||||
messages: list[PromptMessage],
|
||||
round_log: AgentLog,
|
||||
) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None, bool]]:
|
||||
"""Handle a single tool call and return response with files, meta, and validation status.
|
||||
|
||||
Validates required parameters before execution. When validation fails the tool
|
||||
is never invoked — a synthetic error is fed back to the model so it can self-correct
|
||||
without consuming a real iteration.
|
||||
|
||||
Returns:
|
||||
(response_content, tool_files, tool_invoke_meta, is_validation_error).
|
||||
``is_validation_error`` is True when the call was rejected due to missing
|
||||
required parameters, allowing the caller to skip the iteration counter.
|
||||
"""
|
||||
# Find tool
|
||||
tool_instance = self._find_tool_by_name(tool_name)
|
||||
if not tool_instance:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
|
||||
# Get tool metadata (provider, icon, etc.)
|
||||
tool_metadata = self._get_tool_metadata(tool_instance)
|
||||
|
||||
# Create tool call log
|
||||
tool_call_log = self._create_log(
|
||||
label=f"CALL {tool_name}",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata=tool_metadata,
|
||||
)
|
||||
yield tool_call_log
|
||||
|
||||
# Validate required parameters before execution to avoid wasted invocations
|
||||
validation_error = self._validate_tool_args(tool_instance, tool_args)
|
||||
if validation_error:
|
||||
tool_call_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_call_log.error = validation_error
|
||||
tool_call_log.data = {**tool_call_log.data, "error": validation_error}
|
||||
yield tool_call_log
|
||||
|
||||
messages.append(ToolPromptMessage(content=validation_error, tool_call_id=tool_call_id, name=tool_name))
|
||||
return validation_error, [], None, True
|
||||
|
||||
# Invoke tool using base class method with error handling
|
||||
try:
|
||||
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name)
|
||||
|
||||
yield self._finish_log(
|
||||
tool_call_log,
|
||||
data={
|
||||
**tool_call_log.data,
|
||||
"output": response_content,
|
||||
"files": len(tool_files),
|
||||
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
|
||||
},
|
||||
)
|
||||
final_content = response_content or "Tool executed successfully"
|
||||
# Add tool response to messages
|
||||
messages.append(
|
||||
ToolPromptMessage(
|
||||
content=final_content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
return response_content, tool_files, tool_invoke_meta, False
|
||||
except Exception as e:
|
||||
# Tool invocation failed, yield error log
|
||||
error_message = str(e)
|
||||
tool_call_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_call_log.error = error_message
|
||||
tool_call_log.data = {
|
||||
**tool_call_log.data,
|
||||
"error": error_message,
|
||||
}
|
||||
yield tool_call_log
|
||||
|
||||
# Add error message to conversation
|
||||
error_content = f"Tool execution failed: {error_message}"
|
||||
messages.append(
|
||||
ToolPromptMessage(
|
||||
content=error_content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
return error_content, [], None, False
|
||||
@ -1,418 +0,0 @@
|
||||
"""ReAct strategy implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from core.model_manager import ModelInstance
|
||||
from graphon.file import File
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
PromptMessage,
|
||||
SystemPromptMessage,
|
||||
)
|
||||
|
||||
from .base import AgentPattern, ToolInvokeHook
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
|
||||
class ReActStrategy(AgentPattern):
|
||||
"""ReAct strategy using reasoning and acting approach."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_instance: ModelInstance,
|
||||
tools: list[Tool],
|
||||
context: ExecutionContext,
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
files: list[File] = [],
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
instruction: str = "",
|
||||
):
|
||||
"""Initialize the ReAct strategy with instruction support."""
|
||||
super().__init__(
|
||||
model_instance=model_instance,
|
||||
tools=tools,
|
||||
context=context,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
files=files,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
self.instruction = instruction
|
||||
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the ReAct agent strategy."""
|
||||
# Initialize tracking
|
||||
agent_scratchpad: list[AgentScratchpadUnit] = []
|
||||
iteration_step: int = 1
|
||||
max_iterations: int = self.max_iterations + 1
|
||||
react_state: bool = True
|
||||
total_usage: dict[str, Any] = {"usage": None}
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
final_text: str = ""
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Add "Observation" to stop sequences
|
||||
if "Observation" not in stop:
|
||||
stop = stop.copy()
|
||||
stop.append("Observation")
|
||||
|
||||
while react_state and iteration_step <= max_iterations:
|
||||
react_state = False
|
||||
round_log = self._create_log(
|
||||
label=f"ROUND {iteration_step}",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
yield round_log
|
||||
|
||||
# Build prompt with/without tools based on iteration
|
||||
include_tools = iteration_step < max_iterations
|
||||
current_messages = self._build_prompt_with_react_format(
|
||||
prompt_messages, agent_scratchpad, include_tools, self.instruction
|
||||
)
|
||||
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model_name} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata={
|
||||
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
|
||||
},
|
||||
)
|
||||
yield model_log
|
||||
|
||||
# Track usage for this round only
|
||||
round_usage: dict[str, Any] = {"usage": None}
|
||||
|
||||
# Use current messages directly (files are handled by base class if needed)
|
||||
messages_to_use = current_messages
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
prompt_messages=messages_to_use,
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# Process response
|
||||
scratchpad, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log, current_messages
|
||||
)
|
||||
agent_scratchpad.append(scratchpad)
|
||||
|
||||
# Accumulate to total usage
|
||||
round_usage_value = round_usage.get("usage")
|
||||
if round_usage_value:
|
||||
self._accumulate_usage(total_usage, round_usage_value)
|
||||
|
||||
# Update finish reason
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
|
||||
# Check if we have an action to execute
|
||||
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
|
||||
react_state = True
|
||||
# Execute tool
|
||||
observation, tool_files = yield from self._handle_tool_call(
|
||||
scratchpad.action, current_messages, round_log
|
||||
)
|
||||
scratchpad.observation = observation
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
|
||||
# Add observation to scratchpad for display
|
||||
yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
|
||||
else:
|
||||
# Extract final answer
|
||||
if scratchpad.action and scratchpad.action.action_input:
|
||||
final_answer = scratchpad.action.action_input
|
||||
if isinstance(final_answer, dict):
|
||||
final_answer = json.dumps(final_answer, ensure_ascii=False)
|
||||
final_text = str(final_answer)
|
||||
elif scratchpad.thought:
|
||||
# If no action but we have thought, use thought as final answer
|
||||
final_text = scratchpad.thought
|
||||
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
"thought": scratchpad.thought,
|
||||
"action": scratchpad.action_str if scratchpad.action else None,
|
||||
"observation": scratchpad.observation or None,
|
||||
"final_answer": final_text if not react_state else None,
|
||||
},
|
||||
usage=round_usage.get("usage"),
|
||||
)
|
||||
iteration_step += 1
|
||||
|
||||
# Return final result
|
||||
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
return AgentResult(
|
||||
text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason
|
||||
)
|
||||
|
||||
def _build_prompt_with_react_format(
|
||||
self,
|
||||
original_messages: list[PromptMessage],
|
||||
agent_scratchpad: list[AgentScratchpadUnit],
|
||||
include_tools: bool = True,
|
||||
instruction: str = "",
|
||||
) -> list[PromptMessage]:
|
||||
"""Build prompt messages with ReAct format."""
|
||||
# Copy messages to avoid modifying original
|
||||
messages = list(original_messages)
|
||||
|
||||
# Find and update the system prompt that should already exist
|
||||
system_prompt_found = False
|
||||
for i, msg in enumerate(messages):
|
||||
if isinstance(msg, SystemPromptMessage):
|
||||
system_prompt_found = True
|
||||
# The system prompt from frontend already has the template, just replace placeholders
|
||||
|
||||
# Format tools
|
||||
tools_str = ""
|
||||
tool_names = []
|
||||
if include_tools and self.tools:
|
||||
# Convert tools to prompt message tools format
|
||||
prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
|
||||
tool_names = [tool.name for tool in prompt_tools]
|
||||
|
||||
# Format tools as JSON for comprehensive information
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
tools_str = json.dumps(jsonable_encoder(prompt_tools), indent=2)
|
||||
tool_names_str = ", ".join(f'"{name}"' for name in tool_names)
|
||||
else:
|
||||
tools_str = "No tools available"
|
||||
tool_names_str = ""
|
||||
|
||||
# Replace placeholders in the existing system prompt
|
||||
updated_content = msg.content
|
||||
assert isinstance(updated_content, str)
|
||||
updated_content = updated_content.replace("{{instruction}}", instruction)
|
||||
updated_content = updated_content.replace("{{tools}}", tools_str)
|
||||
updated_content = updated_content.replace("{{tool_names}}", tool_names_str)
|
||||
|
||||
# Create new SystemPromptMessage with updated content
|
||||
messages[i] = SystemPromptMessage(content=updated_content)
|
||||
break
|
||||
|
||||
# If no system prompt found, that's unexpected but add scratchpad anyway
|
||||
if not system_prompt_found:
|
||||
# This shouldn't happen if frontend is working correctly
|
||||
pass
|
||||
|
||||
# Format agent scratchpad
|
||||
scratchpad_str = ""
|
||||
if agent_scratchpad:
|
||||
scratchpad_parts: list[str] = []
|
||||
for unit in agent_scratchpad:
|
||||
if unit.thought:
|
||||
scratchpad_parts.append(f"Thought: {unit.thought}")
|
||||
if unit.action_str:
|
||||
scratchpad_parts.append(f"Action:\n```\n{unit.action_str}\n```")
|
||||
if unit.observation:
|
||||
scratchpad_parts.append(f"Observation: {unit.observation}")
|
||||
scratchpad_str = "\n".join(scratchpad_parts)
|
||||
|
||||
# If there's a scratchpad, append it to the last message
|
||||
if scratchpad_str:
|
||||
messages.append(AssistantPromptMessage(content=scratchpad_str))
|
||||
|
||||
return messages
|
||||
|
||||
def _handle_chunks(
|
||||
self,
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
llm_usage: dict[str, Any],
|
||||
model_log: AgentLog,
|
||||
current_messages: list[PromptMessage],
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
tuple[AgentScratchpadUnit, str | None],
|
||||
]:
|
||||
"""Handle LLM response chunks and extract action/thought.
|
||||
|
||||
Returns a tuple of (scratchpad_unit, finish_reason).
|
||||
"""
|
||||
usage_dict: dict[str, Any] = {}
|
||||
|
||||
# Convert non-streaming to streaming format if needed
|
||||
if isinstance(chunks, LLMResult):
|
||||
result = chunks
|
||||
|
||||
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
|
||||
yield LLMResultChunk(
|
||||
model=result.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=result.message,
|
||||
usage=result.usage,
|
||||
finish_reason=None,
|
||||
),
|
||||
system_fingerprint=result.system_fingerprint or "",
|
||||
)
|
||||
|
||||
streaming_chunks = result_to_chunks()
|
||||
else:
|
||||
streaming_chunks = chunks
|
||||
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)
|
||||
|
||||
# Initialize scratchpad unit
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="",
|
||||
thought="",
|
||||
action_str="",
|
||||
observation="",
|
||||
action=None,
|
||||
)
|
||||
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Process chunks
|
||||
for chunk in react_chunks:
|
||||
if isinstance(chunk, AgentScratchpadUnit.Action):
|
||||
# Action detected
|
||||
action_str = json.dumps(chunk.model_dump())
|
||||
scratchpad.agent_response = (scratchpad.agent_response or "") + action_str
|
||||
scratchpad.action_str = action_str
|
||||
scratchpad.action = chunk
|
||||
|
||||
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
|
||||
else:
|
||||
# Text chunk
|
||||
chunk_text = str(chunk)
|
||||
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
|
||||
scratchpad.thought = (scratchpad.thought or "") + chunk_text
|
||||
|
||||
yield self._create_text_chunk(chunk_text, current_messages)
|
||||
|
||||
# Update usage
|
||||
if usage_dict.get("usage"):
|
||||
if llm_usage.get("usage"):
|
||||
self._accumulate_usage(llm_usage, usage_dict["usage"])
|
||||
else:
|
||||
llm_usage["usage"] = usage_dict["usage"]
|
||||
|
||||
# Clean up thought
|
||||
scratchpad.thought = (scratchpad.thought or "").strip() or "I am thinking about how to help you"
|
||||
|
||||
# Finish model log
|
||||
yield self._finish_log(
|
||||
model_log,
|
||||
data={
|
||||
"thought": scratchpad.thought,
|
||||
"action": scratchpad.action_str if scratchpad.action else None,
|
||||
},
|
||||
usage=llm_usage.get("usage"),
|
||||
)
|
||||
|
||||
return scratchpad, finish_reason
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
action: AgentScratchpadUnit.Action,
|
||||
prompt_messages: list[PromptMessage],
|
||||
round_log: AgentLog,
|
||||
) -> Generator[AgentLog, None, tuple[str, list[File]]]:
|
||||
"""Handle tool call and return observation with files."""
|
||||
tool_name = action.action_name
|
||||
tool_args: dict[str, Any] | str = action.action_input
|
||||
|
||||
# Find tool instance first to get metadata
|
||||
tool_instance = self._find_tool_by_name(tool_name)
|
||||
tool_metadata = self._get_tool_metadata(tool_instance) if tool_instance else {}
|
||||
|
||||
# Start tool log with tool metadata
|
||||
tool_log = self._create_log(
|
||||
label=f"CALL {tool_name}",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata=tool_metadata,
|
||||
)
|
||||
yield tool_log
|
||||
|
||||
if not tool_instance:
|
||||
# Finish tool log with error
|
||||
yield self._finish_log(
|
||||
tool_log,
|
||||
data={
|
||||
**tool_log.data,
|
||||
"error": f"Tool {tool_name} not found",
|
||||
},
|
||||
)
|
||||
return f"Tool {tool_name} not found", []
|
||||
|
||||
# Ensure tool_args is a dict
|
||||
tool_args_dict: dict[str, Any]
|
||||
if isinstance(tool_args, str):
|
||||
try:
|
||||
tool_args_dict = json.loads(tool_args)
|
||||
except json.JSONDecodeError:
|
||||
tool_args_dict = {"input": tool_args}
|
||||
elif not isinstance(tool_args, dict):
|
||||
tool_args_dict = {"input": str(tool_args)}
|
||||
else:
|
||||
tool_args_dict = tool_args
|
||||
|
||||
# Invoke tool using base class method with error handling
|
||||
try:
|
||||
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name)
|
||||
|
||||
# Finish tool log
|
||||
yield self._finish_log(
|
||||
tool_log,
|
||||
data={
|
||||
**tool_log.data,
|
||||
"output": response_content,
|
||||
"files": len(tool_files),
|
||||
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
|
||||
},
|
||||
)
|
||||
|
||||
return response_content or "Tool executed successfully", tool_files
|
||||
except Exception as e:
|
||||
# Tool invocation failed, yield error log
|
||||
error_message = str(e)
|
||||
tool_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_log.error = error_message
|
||||
tool_log.data = {
|
||||
**tool_log.data,
|
||||
"error": error_message,
|
||||
}
|
||||
yield tool_log
|
||||
|
||||
return f"Tool execution failed: {error_message}", []
|
||||
@ -1,108 +0,0 @@
|
||||
"""Strategy factory for creating agent strategies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.agent.entities import AgentEntity, ExecutionContext
|
||||
from core.model_manager import ModelInstance
|
||||
from graphon.file.models import File
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
|
||||
from .base import AgentPattern, ToolInvokeHook
|
||||
from .function_call import FunctionCallStrategy
|
||||
from .react import ReActStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
|
||||
class StrategyFactory:
|
||||
"""Factory for creating agent strategies based on model features."""
|
||||
|
||||
# Tool calling related features
|
||||
TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL}
|
||||
|
||||
@staticmethod
|
||||
def create_strategy(
|
||||
model_features: list[ModelFeature],
|
||||
model_instance: ModelInstance,
|
||||
context: ExecutionContext,
|
||||
tools: list[Tool],
|
||||
files: list[File],
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
agent_strategy: AgentEntity.Strategy | None = None,
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
instruction: str = "",
|
||||
) -> AgentPattern:
|
||||
"""
|
||||
Create an appropriate strategy based on model features.
|
||||
|
||||
Args:
|
||||
model_features: List of model features/capabilities
|
||||
model_instance: Model instance to use
|
||||
context: Execution context containing trace/audit information
|
||||
tools: Available tools
|
||||
files: Available files
|
||||
max_iterations: Maximum iterations for the strategy
|
||||
workflow_call_depth: Depth of workflow calls
|
||||
agent_strategy: Optional explicit strategy override
|
||||
tool_invoke_hook: Optional hook for custom tool invocation (e.g., agent_invoke)
|
||||
instruction: Optional instruction for ReAct strategy
|
||||
|
||||
Returns:
|
||||
AgentStrategy instance
|
||||
"""
|
||||
|
||||
# If explicit strategy is provided and it's Function Calling, try to use it if supported
|
||||
if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
||||
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
|
||||
return FunctionCallStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
# Fallback to ReAct if FC is requested but not supported
|
||||
|
||||
# If explicit strategy is Chain of Thought (ReAct)
|
||||
if agent_strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||
return ReActStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
# Default auto-selection logic
|
||||
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
|
||||
# Model supports native function calling
|
||||
return FunctionCallStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
else:
|
||||
# Use ReAct strategy for models without function calling
|
||||
return ReActStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
@ -84,7 +84,7 @@ class AgentStrategyEntity(BaseModel):
|
||||
identity: AgentStrategyIdentity
|
||||
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
|
||||
description: I18nObject = Field(..., description="The description of the agent strategy")
|
||||
output_schema: dict | None = None
|
||||
output_schema: dict[str, Any] | None = None
|
||||
features: list[AgentFeature] | None = None
|
||||
meta_version: str | None = None
|
||||
# pydantic configs
|
||||
|
||||
@ -22,8 +22,8 @@ class SensitiveWordAvoidanceConfigManager:
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(
|
||||
cls, tenant_id: str, config: dict, only_structure_validate: bool = False
|
||||
) -> tuple[dict, list[str]]:
|
||||
cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False
|
||||
) -> tuple[dict[str, Any], list[str]]:
|
||||
if not config.get("sensitive_word_avoidance"):
|
||||
config["sensitive_word_avoidance"] = {"enabled": False}
|
||||
|
||||
|
||||
@ -138,7 +138,9 @@ class DatasetConfigManager:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(
|
||||
cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]
|
||||
) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for dataset feature
|
||||
|
||||
@ -172,7 +174,7 @@ class DatasetConfigManager:
|
||||
return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]
|
||||
|
||||
@classmethod
|
||||
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict):
|
||||
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]):
|
||||
"""
|
||||
Extract dataset config for legacy compatibility
|
||||
|
||||
|
||||
@ -41,7 +41,7 @@ class ModelConfigManager:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for model config
|
||||
|
||||
@ -108,7 +108,7 @@ class ModelConfigManager:
|
||||
return dict(config), ["model"]
|
||||
|
||||
@classmethod
|
||||
def validate_model_completion_params(cls, cp: dict):
|
||||
def validate_model_completion_params(cls, cp: dict[str, Any]):
|
||||
# model.completion_params
|
||||
if not isinstance(cp, dict):
|
||||
raise ValueError("model.completion_params must be of object type")
|
||||
|
||||
@ -65,7 +65,7 @@ class PromptTemplateConfigManager:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate pre_prompt and set defaults for prompt feature
|
||||
depending on the config['model']
|
||||
@ -130,7 +130,7 @@ class PromptTemplateConfigManager:
|
||||
return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]
|
||||
|
||||
@classmethod
|
||||
def validate_post_prompt_and_set_defaults(cls, config: dict):
|
||||
def validate_post_prompt_and_set_defaults(cls, config: dict[str, Any]):
|
||||
"""
|
||||
Validate post_prompt and set defaults for prompt feature
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import re
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||
|
||||
@ -82,7 +82,7 @@ class BasicVariablesConfigManager:
|
||||
return variable_entities, external_data_variables
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, tenant_id: str, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for user input form
|
||||
|
||||
@ -99,7 +99,7 @@ class BasicVariablesConfigManager:
|
||||
return config, related_config_keys
|
||||
|
||||
@classmethod
|
||||
def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_variables_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for user input form
|
||||
|
||||
@ -164,7 +164,9 @@ class BasicVariablesConfigManager:
|
||||
return config, ["user_input_form"]
|
||||
|
||||
@classmethod
|
||||
def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_external_data_tools_and_set_defaults(
|
||||
cls, tenant_id: str, config: dict[str, Any]
|
||||
) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for external data fetch feature
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ class FileUploadConfigManager:
|
||||
return FileUploadConfig.model_validate(file_upload_dict)
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for file upload feature
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
|
||||
|
||||
@ -13,7 +15,7 @@ class AppConfigModel(BaseModel):
|
||||
|
||||
class MoreLikeThisConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> bool:
|
||||
def convert(cls, config: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@ -23,7 +25,7 @@ class MoreLikeThisConfigManager:
|
||||
return AppConfigModel.model_validate(validated_config).more_like_this.enabled
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
try:
|
||||
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
|
||||
except ValidationError:
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class OpeningStatementConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> tuple[str, list]:
|
||||
def convert(cls, config: dict[str, Any]) -> tuple[str, list[str]]:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@ -15,7 +18,7 @@ class OpeningStatementConfigManager:
|
||||
return opening_statement, suggested_questions_list
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for opening statement feature
|
||||
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class RetrievalResourceConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> bool:
|
||||
def convert(cls, config: dict[str, Any]) -> bool:
|
||||
show_retrieve_source = False
|
||||
retriever_resource_dict = config.get("retriever_resource")
|
||||
if retriever_resource_dict:
|
||||
@ -10,7 +13,7 @@ class RetrievalResourceConfigManager:
|
||||
return show_retrieve_source
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for retriever resource feature
|
||||
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class SpeechToTextConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> bool:
|
||||
def convert(cls, config: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@ -15,7 +18,7 @@ class SpeechToTextConfigManager:
|
||||
return speech_to_text
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for speech to text feature
|
||||
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class SuggestedQuestionsAfterAnswerConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> bool:
|
||||
def convert(cls, config: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@ -15,7 +18,7 @@ class SuggestedQuestionsAfterAnswerConfigManager:
|
||||
return suggested_questions_after_answer
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for suggested questions feature
|
||||
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import TextToSpeechEntity
|
||||
|
||||
|
||||
class TextToSpeechConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict):
|
||||
def convert(cls, config: dict[str, Any]):
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@ -22,7 +24,7 @@ class TextToSpeechConfigManager:
|
||||
return text_to_speech
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for text to speech feature
|
||||
|
||||
|
||||
@ -177,14 +177,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
# always enable retriever resource in debugger mode
|
||||
app_config.additional_features.show_retrieve_source = True # type: ignore
|
||||
|
||||
# Resolve parent_message_id for thread continuity
|
||||
if invoke_from == InvokeFrom.SERVICE_API:
|
||||
parent_message_id: str | None = UUID_NIL
|
||||
else:
|
||||
parent_message_id = args.get("parent_message_id")
|
||||
if not parent_message_id and conversation:
|
||||
parent_message_id = self._resolve_latest_message_id(conversation.id)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = AdvancedChatAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
@ -196,7 +188,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
),
|
||||
query=query,
|
||||
files=list(file_objs),
|
||||
parent_message_id=parent_message_id,
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
@ -697,17 +689,3 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
else:
|
||||
logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def _resolve_latest_message_id(conversation_id: str) -> str | None:
|
||||
"""Auto-resolve parent_message_id to the latest message when client doesn't provide one."""
|
||||
from sqlalchemy import select
|
||||
|
||||
stmt = (
|
||||
select(Message.id)
|
||||
.where(Message.conversation_id == conversation_id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
latest_id = db.session.scalar(stmt)
|
||||
return str(latest_id) if latest_id else None
|
||||
|
||||
@ -57,7 +57,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, Any, None]:
|
||||
) -> Generator[dict[str, Any] | str, Any, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@ -88,7 +88,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, Any, None]:
|
||||
) -> Generator[dict[str, Any] | str, Any, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
|
||||
@ -1,12 +1,15 @@
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agent.agent_app_runner import AgentAppRunner
|
||||
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
|
||||
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
|
||||
from core.agent.entities import AgentEntity
|
||||
from core.agent.fc_agent_runner import FunctionCallAgentRunner
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
@ -189,8 +192,24 @@ class AgentChatAppRunner(AppRunner):
|
||||
message_result = db.session.scalar(msg_stmt)
|
||||
if message_result is None:
|
||||
raise ValueError("Message not found")
|
||||
db.session.close()
|
||||
|
||||
runner = AgentAppRunner(
|
||||
runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
|
||||
# start agent runner
|
||||
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||
# check LLM mode
|
||||
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT:
|
||||
runner_cls = CotChatAgentRunner
|
||||
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION:
|
||||
runner_cls = CotCompletionAgentRunner
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")
|
||||
elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
||||
runner_cls = FunctionCallAgentRunner
|
||||
else:
|
||||
raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}")
|
||||
|
||||
runner = runner_cls(
|
||||
tenant_id=app_config.tenant_id,
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation_result,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
@ -56,7 +56,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@ -87,7 +87,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
|
||||
@ -24,7 +24,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
return cls.convert_blocking_full_response(response)
|
||||
else:
|
||||
|
||||
def _generate_full_response() -> Generator[dict | str, Any, None]:
|
||||
def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]:
|
||||
yield from cls.convert_stream_full_response(response)
|
||||
|
||||
return _generate_full_response()
|
||||
@ -33,7 +33,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
return cls.convert_blocking_simple_response(response)
|
||||
else:
|
||||
|
||||
def _generate_simple_response() -> Generator[dict | str, Any, None]:
|
||||
def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]:
|
||||
yield from cls.convert_stream_simple_response(response)
|
||||
|
||||
return _generate_simple_response()
|
||||
@ -52,14 +52,14 @@ class AppGenerateResponseConverter(ABC):
|
||||
@abstractmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
@ -56,7 +56,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@ -87,7 +87,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
|
||||
@ -1,53 +0,0 @@
|
||||
"""Legacy Response Adapter for transparent upgrade.
|
||||
|
||||
When old apps (chat/completion/agent-chat) run through the Agent V2
|
||||
workflow engine via transparent upgrade, the SSE events are in workflow
|
||||
format (workflow_started, node_started, etc.). This adapter filters out
|
||||
workflow-specific events and passes through only the events that old
|
||||
clients expect (message, message_end, etc.).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
WORKFLOW_ONLY_EVENTS = frozenset({
|
||||
"workflow_started",
|
||||
"workflow_finished",
|
||||
"node_started",
|
||||
"node_finished",
|
||||
"iteration_started",
|
||||
"iteration_next",
|
||||
"iteration_completed",
|
||||
})
|
||||
|
||||
|
||||
def adapt_workflow_stream_for_legacy(
|
||||
stream: Generator[str, None, None],
|
||||
) -> Generator[str, None, None]:
|
||||
"""Filter workflow-specific SSE events from a streaming response.
|
||||
|
||||
Passes through message, message_end, agent_log, error, ping events.
|
||||
Suppresses workflow_started, workflow_finished, node_started, node_finished.
|
||||
|
||||
This makes the SSE stream look more like what old easy-UI apps produce,
|
||||
while still carrying the actual LLM response content.
|
||||
"""
|
||||
for chunk in stream:
|
||||
if not chunk or not chunk.strip():
|
||||
yield chunk
|
||||
continue
|
||||
|
||||
try:
|
||||
if chunk.startswith("data: "):
|
||||
data = json.loads(chunk[6:])
|
||||
event = data.get("event", "")
|
||||
if event in WORKFLOW_ONLY_EVENTS:
|
||||
continue
|
||||
yield chunk
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
yield chunk
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
@ -55,7 +55,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@ -85,7 +85,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
@ -17,7 +17,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = WorkflowAppBlockingResponse
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
@ -26,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return dict(blocking_response.model_dump())
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
@ -37,7 +37,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@ -66,7 +66,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
|
||||
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
|
||||
@ -34,7 +36,9 @@ class PipelineConfigManager(BaseAppConfigManager):
|
||||
return pipeline_config
|
||||
|
||||
@classmethod
|
||||
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
|
||||
def config_validate(
|
||||
cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Validate for pipeline config
|
||||
|
||||
|
||||
@ -782,7 +782,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
user_id: str,
|
||||
all_files: list,
|
||||
datasource_info: Mapping[str, Any],
|
||||
next_page_parameters: dict | None = None,
|
||||
next_page_parameters: dict[str, Any] | None = None,
|
||||
):
|
||||
"""
|
||||
Get files in a folder.
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
@ -37,7 +37,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@ -66,7 +66,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
|
||||
@ -682,15 +682,16 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
|
||||
def _save_workflow_app_log(self, *, session: Session, workflow_run_id: str | None):
|
||||
invoke_from = self._application_generate_entity.invoke_from
|
||||
if invoke_from == InvokeFrom.SERVICE_API:
|
||||
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
|
||||
elif invoke_from == InvokeFrom.EXPLORE:
|
||||
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
|
||||
elif invoke_from == InvokeFrom.WEB_APP:
|
||||
created_from = WorkflowAppLogCreatedFrom.WEB_APP
|
||||
else:
|
||||
# not save log for debugging
|
||||
return
|
||||
match invoke_from:
|
||||
case InvokeFrom.SERVICE_API:
|
||||
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
|
||||
case InvokeFrom.EXPLORE:
|
||||
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
|
||||
case InvokeFrom.WEB_APP:
|
||||
created_from = WorkflowAppLogCreatedFrom.WEB_APP
|
||||
case InvokeFrom.DEBUGGER | InvokeFrom.TRIGGER | InvokeFrom.PUBLISHED_PIPELINE | InvokeFrom.VALIDATION:
|
||||
# not save log for debugging
|
||||
return
|
||||
|
||||
if not workflow_run_id:
|
||||
return
|
||||
|
||||
@ -146,6 +146,8 @@ class WorkflowBasedAppRunner:
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# Use the provided graph_runtime_state for consistent state management
|
||||
|
||||
node_factory = DifyNodeFactory.from_graph_init_context(
|
||||
graph_init_context=graph_init_context,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
@ -1,72 +0,0 @@
|
||||
"""
|
||||
LLM Generation Detail entities.
|
||||
|
||||
Defines the structure for storing and transmitting LLM generation details
|
||||
including reasoning content, tool calls, and their sequence.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ContentSegment(BaseModel):
|
||||
"""Represents a content segment in the generation sequence."""
|
||||
|
||||
type: Literal["content"] = "content"
|
||||
start: int = Field(..., description="Start position in the text")
|
||||
end: int = Field(..., description="End position in the text")
|
||||
|
||||
|
||||
class ReasoningSegment(BaseModel):
|
||||
"""Represents a reasoning segment in the generation sequence."""
|
||||
|
||||
type: Literal["reasoning"] = "reasoning"
|
||||
index: int = Field(..., description="Index into reasoning_content array")
|
||||
|
||||
|
||||
class ToolCallSegment(BaseModel):
|
||||
"""Represents a tool call segment in the generation sequence."""
|
||||
|
||||
type: Literal["tool_call"] = "tool_call"
|
||||
index: int = Field(..., description="Index into tool_calls array")
|
||||
|
||||
|
||||
SequenceSegment = ContentSegment | ReasoningSegment | ToolCallSegment
|
||||
|
||||
|
||||
class ToolCallDetail(BaseModel):
|
||||
"""Represents a tool call with its arguments and result."""
|
||||
|
||||
id: str = Field(default="", description="Unique identifier for the tool call")
|
||||
name: str = Field(..., description="Name of the tool")
|
||||
arguments: str = Field(default="", description="JSON string of tool arguments")
|
||||
result: str = Field(default="", description="Result from the tool execution")
|
||||
elapsed_time: float | None = Field(default=None, description="Elapsed time in seconds")
|
||||
icon: str | dict | None = Field(default=None, description="Icon of the tool")
|
||||
icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
|
||||
|
||||
|
||||
class LLMGenerationDetailData(BaseModel):
|
||||
"""
|
||||
Domain model for LLM generation detail.
|
||||
|
||||
Contains the structured data for reasoning content, tool calls,
|
||||
and their display sequence.
|
||||
"""
|
||||
|
||||
reasoning_content: list[str] = Field(default_factory=list, description="List of reasoning segments")
|
||||
tool_calls: list[ToolCallDetail] = Field(default_factory=list, description="List of tool call details")
|
||||
sequence: list[SequenceSegment] = Field(default_factory=list, description="Display order of segments")
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""Check if there's any meaningful generation detail."""
|
||||
return not self.reasoning_content and not self.tool_calls
|
||||
|
||||
def to_response_dict(self) -> dict:
|
||||
"""Convert to dictionary for API response."""
|
||||
return {
|
||||
"reasoning_content": self.reasoning_content,
|
||||
"tool_calls": [tc.model_dump() for tc in self.tool_calls],
|
||||
"sequence": [seg.model_dump() for seg in self.sequence],
|
||||
}
|
||||
@ -521,7 +521,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
|
||||
node_type: str
|
||||
title: str
|
||||
created_at: int
|
||||
extras: dict = Field(default_factory=dict)
|
||||
extras: dict[str, Any] = Field(default_factory=dict)
|
||||
metadata: Mapping = {}
|
||||
inputs: Mapping = {}
|
||||
inputs_truncated: bool = False
|
||||
@ -547,7 +547,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
|
||||
title: str
|
||||
index: int
|
||||
created_at: int
|
||||
extras: dict = Field(default_factory=dict)
|
||||
extras: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_NEXT
|
||||
workflow_run_id: str
|
||||
@ -571,7 +571,7 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
|
||||
outputs: Mapping | None = None
|
||||
outputs_truncated: bool = False
|
||||
created_at: int
|
||||
extras: dict | None = None
|
||||
extras: dict[str, Any] | None = None
|
||||
inputs: Mapping | None = None
|
||||
inputs_truncated: bool = False
|
||||
status: WorkflowNodeExecutionStatus
|
||||
@ -602,7 +602,7 @@ class LoopNodeStartStreamResponse(StreamResponse):
|
||||
node_type: str
|
||||
title: str
|
||||
created_at: int
|
||||
extras: dict = Field(default_factory=dict)
|
||||
extras: dict[str, Any] = Field(default_factory=dict)
|
||||
metadata: Mapping = {}
|
||||
inputs: Mapping = {}
|
||||
inputs_truncated: bool = False
|
||||
@ -653,7 +653,7 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
|
||||
outputs: Mapping | None = None
|
||||
outputs_truncated: bool = False
|
||||
created_at: int
|
||||
extras: dict | None = None
|
||||
extras: dict[str, Any] | None = None
|
||||
inputs: Mapping | None = None
|
||||
inputs_truncated: bool = False
|
||||
status: WorkflowNodeExecutionStatus
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user