Compare commits

..

1 Commits

Author SHA1 Message Date
50156f822b add podman compose middleware helpers 2026-04-02 02:29:38 +08:00
4189 changed files with 57082 additions and 112701 deletions

View File

@ -1,79 +0,0 @@
---
name: e2e-cucumber-playwright
description: Write, update, or review Dify end-to-end tests under `e2e/` that use Cucumber, Gherkin, and Playwright. Use when the task involves `.feature` files, `features/step-definitions/`, `features/support/`, `DifyWorld`, scenario tags, locator/assertion choices, or E2E testing best practices for this repository.
---
# Dify E2E Cucumber + Playwright
Use this skill for Dify's repository-level E2E suite in `e2e/`. Use [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) as the canonical guide for local architecture and conventions, then apply Playwright/Cucumber best practices only where they fit the current suite.
## Scope
- Use this skill for `.feature` files, Cucumber step definitions, `DifyWorld`, hooks, tags, and E2E review work under `e2e/`.
- Do not use this skill for Vitest or React Testing Library work under `web/`; use `frontend-testing` instead.
- Do not use this skill for backend test or API review tasks under `api/`.
## Read Order
1. Read [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) first.
2. Read only the files directly involved in the task:
- target `.feature` files under `e2e/features/`
- related step files under `e2e/features/step-definitions/`
- `e2e/features/support/hooks.ts` and `e2e/features/support/world.ts` when session lifecycle or shared state matters
- `e2e/scripts/run-cucumber.ts` and `e2e/cucumber.config.ts` when tags or execution flow matter
3. Read [`references/playwright-best-practices.md`](references/playwright-best-practices.md) only when locator, assertion, isolation, or waiting choices are involved.
4. Read [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md) only when scenario wording, step granularity, tags, or expression design are involved.
5. Re-check official docs with Context7 before introducing a new Playwright or Cucumber pattern.
## Local Rules
- `e2e/` uses Cucumber for scenarios and Playwright as the browser layer.
- `DifyWorld` is the per-scenario context object. Type `this` as `DifyWorld` and use `async function`, not arrow functions.
- Keep glue organized by capability under `e2e/features/step-definitions/`; use `common/` only for broadly reusable steps.
- Browser session behavior comes from `features/support/hooks.ts`:
- default: authenticated session with shared storage state
- `@unauthenticated`: clean browser context
- `@authenticated`: readability/selective-run tag only unless implementation changes
- `@fresh`: only for `e2e:full*` flows
- Do not import Playwright Test runner patterns that bypass the current Cucumber + `DifyWorld` architecture unless the task is explicitly about changing that architecture.
## Workflow
1. Rebuild local context.
- Inspect the target feature area.
- Reuse an existing step when wording and behavior already match.
- Add a new step only for a genuinely new user action or assertion.
- Keep edits close to the current capability folder unless the step is broadly reusable.
2. Write behavior-first scenarios.
- Describe user-observable behavior, not DOM mechanics.
- Keep each scenario focused on one workflow or outcome.
- Keep scenarios independent and re-runnable.
3. Write step definitions in the local style.
- Keep one step to one user-visible action or one assertion.
- Prefer Cucumber Expressions such as `{string}` and `{int}`.
- Scope locators to stable containers when the page has repeated elements.
- Avoid page-object layers or extra helper abstractions unless repeated complexity clearly justifies them.
4. Use Playwright in the local style.
- Prefer user-facing locators: `getByRole`, `getByLabel`, `getByPlaceholder`, `getByText`, then `getByTestId` for explicit contracts.
- Use web-first `expect(...)` assertions.
- Do not use `waitForTimeout`, manual polling, or raw visibility checks when a locator action or retrying assertion already expresses the behavior.
5. Validate narrowly.
- Run the narrowest tagged scenario or flow that exercises the change.
- Run `pnpm -C e2e check`.
- Broaden verification only when the change affects hooks, tags, setup, or shared step semantics.
## Review Checklist
- Does the scenario describe behavior rather than implementation?
- Does it fit the current session model, tags, and `DifyWorld` usage?
- Should an existing step be reused instead of adding a new one?
- Are locators user-facing and assertions web-first?
- Does the change introduce hidden coupling across scenarios, tags, or instance state?
- Does it document or implement behavior that differs from the real hooks or configuration?
Lead findings with correctness, flake risk, and architecture drift.
## References
- [`references/playwright-best-practices.md`](references/playwright-best-practices.md)
- [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md)

View File

@ -1,4 +0,0 @@
interface:
display_name: "E2E Cucumber + Playwright"
short_description: "Write and review Dify E2E scenarios."
default_prompt: "Use $e2e-cucumber-playwright to write or review a Dify E2E scenario under e2e/."

View File

@ -1,93 +0,0 @@
# Cucumber Best Practices For Dify E2E
Use this reference when writing or reviewing Gherkin scenarios, step definitions, parameter expressions, and step reuse in Dify's `e2e/` suite.
Official sources:
- https://cucumber.io/docs/guides/10-minute-tutorial/
- https://cucumber.io/docs/cucumber/step-definitions/
- https://cucumber.io/docs/cucumber/cucumber-expressions/
## What Matters Most
### 1. Treat scenarios as executable specifications
Cucumber scenarios should describe examples of behavior, not test implementation recipes.
Apply it like this:
- write what the user does and what should happen
- avoid UI-internal wording such as selector details, DOM structure, or component names
- keep language concrete enough that the scenario reads like living documentation
### 2. Keep scenarios focused
A scenario should usually prove one workflow or business outcome. If a scenario wanders across several unrelated behaviors, split it.
In Dify's suite, this means:
- one capability-focused scenario per feature path
- no long setup chains when existing bootstrap or reusable steps already cover them
- no hidden dependency on another scenario's side effects
### 3. Reuse steps, but only when behavior really matches
Good reuse reduces duplication. Bad reuse hides meaning.
Prefer reuse when:
- the user action is genuinely the same
- the expected outcome is genuinely the same
- the wording stays natural across features
Write a new step when:
- the behavior is materially different
- reusing the old wording would make the scenario misleading
- a supposedly generic step would become an implementation-detail wrapper
### 4. Prefer Cucumber Expressions
Use Cucumber Expressions for parameters unless regex is clearly necessary.
Common examples:
- `{string}` for labels, names, and visible text
- `{int}` for counts
- `{float}` for decimal values
- `{word}` only when the value is truly a single token
Keep expressions readable. If a step needs complicated parsing logic, first ask whether the scenario wording should be simpler.
### 5. Keep step definitions thin and meaningful
Step definitions are glue between Gherkin and automation, not a second abstraction language.
For Dify:
- type `this` as `DifyWorld`
- use `async function`
- keep each step to one user-visible action or assertion
- rely on `DifyWorld` and existing support code for shared context
- avoid leaking cross-scenario state
### 6. Use tags intentionally
Tags should communicate run scope or session semantics, not become ad hoc metadata.
In Dify's current suite:
- capability tags group related scenarios
- `@unauthenticated` changes session behavior
- `@authenticated` is descriptive/selective, not a behavior switch by itself
- `@fresh` belongs to reset/full-install flows only
If a proposed tag implies behavior, verify that hooks or runner configuration actually implement it.
## Review Questions
- Does the scenario read like a real example of product behavior?
- Are the steps behavior-oriented instead of implementation-oriented?
- Is a reused step still truthful in this feature?
- Is a new tag documenting real behavior, or inventing semantics that the suite does not implement?
- Would a new reader understand the outcome without opening the step-definition file?

View File

@ -1,96 +0,0 @@
# Playwright Best Practices For Dify E2E
Use this reference when writing or reviewing locator, assertion, isolation, or synchronization logic for Dify's Cucumber-based E2E suite.
Official sources:
- https://playwright.dev/docs/best-practices
- https://playwright.dev/docs/locators
- https://playwright.dev/docs/test-assertions
- https://playwright.dev/docs/browser-contexts
## What Matters Most
### 1. Keep scenarios isolated
Playwright's model is built around clean browser contexts so one test does not leak into another. In Dify's suite, that principle maps to per-scenario session setup in `features/support/hooks.ts` and `DifyWorld`.
Apply it like this:
- do not depend on another scenario having run first
- do not persist ad hoc scenario state outside `DifyWorld`
- do not couple ordinary scenarios to `@fresh` behavior
- when a flow needs special auth/session semantics, express that through the existing tag model or explicit hook changes
### 2. Prefer user-facing locators
Playwright recommends built-in locators that reflect what users perceive on the page.
Preferred order in this repository:
1. `getByRole`
2. `getByLabel`
3. `getByPlaceholder`
4. `getByText`
5. `getByTestId` when an explicit test contract is the most stable option
Avoid raw CSS/XPath selectors unless no stable user-facing contract exists and adding one is not practical.
Also remember:
- repeated content usually needs scoping to a stable container
- exact text matching is often too brittle when role/name or label already exists
- `getByTestId` is acceptable when semantics are weak but the contract is intentional
### 3. Use web-first assertions
Playwright assertions auto-wait and retry. Prefer them over manual state inspection.
Prefer:
- `await expect(page).toHaveURL(...)`
- `await expect(locator).toBeVisible()`
- `await expect(locator).toBeHidden()`
- `await expect(locator).toBeEnabled()`
- `await expect(locator).toHaveText(...)`
Avoid:
- `expect(await locator.isVisible()).toBe(true)`
- custom polling loops for DOM state
- `waitForTimeout` as synchronization
If a condition genuinely needs custom retry logic, use Playwright's polling/assertion tools deliberately and keep that choice local and explicit.
### 4. Let actions wait for actionability
Locator actions already wait for the element to be actionable. Do not preface every click/fill with extra timing logic unless the action needs a specific visible/ready assertion for clarity.
Good pattern:
- assert a meaningful visible state when that is part of the behavior
- then click/fill/select via locator APIs
Bad pattern:
- stack arbitrary waits before every action
- wait on unstable implementation details instead of the visible state the user cares about
### 5. Match debugging to the current suite
Playwright's wider ecosystem supports traces and rich debugging tools. Dify's current suite already captures:
- full-page screenshots
- page HTML
- console errors
- page errors
Use the existing artifact flow by default. If a task is specifically about improving diagnostics, confirm the change fits the current Cucumber architecture before importing broader Playwright tooling.
## Review Questions
- Would this locator survive DOM refactors that do not change user-visible behavior?
- Is this assertion using Playwright's retrying semantics?
- Is any explicit wait masking a real readiness problem?
- Does this code preserve per-scenario isolation?
- Is a new abstraction really needed, or does it bypass the existing `DifyWorld` + step-definition model?

View File

@ -64,7 +64,7 @@ export const useUpdateAccessMode = () => {
// Component only adds UI behavior.
updateAccessMode({ appId, mode }, {
onSuccess: () => toast.success('...'),
onSuccess: () => Toast.notify({ type: 'success', message: '...' }),
})
// Avoid putting invalidation knowledge in the component.
@ -114,7 +114,10 @@ try {
router.push(`/orders/${order.id}`)
}
catch (error) {
toast.error(error instanceof Error ? error.message : 'Unknown error')
Toast.notify({
type: 'error',
message: error instanceof Error ? error.message : 'Unknown error',
})
}
```

View File

@ -1 +0,0 @@
../../.agents/skills/e2e-cucumber-playwright

100
.github/dependabot.yml vendored
View File

@ -1,6 +1,106 @@
version: 2
updates:
- package-ecosystem: "pip"
directory: "/api"
open-pull-requests-limit: 10
schedule:
interval: "weekly"
groups:
flask:
patterns:
- "flask"
- "flask-*"
- "werkzeug"
- "gunicorn"
google:
patterns:
- "google-*"
- "googleapis-*"
opentelemetry:
patterns:
- "opentelemetry-*"
pydantic:
patterns:
- "pydantic"
- "pydantic-*"
llm:
patterns:
- "langfuse"
- "langsmith"
- "litellm"
- "mlflow*"
- "opik"
- "weave*"
- "arize*"
- "tiktoken"
- "transformers"
database:
patterns:
- "sqlalchemy"
- "psycopg2*"
- "psycogreen"
- "redis*"
- "alembic*"
storage:
patterns:
- "boto3*"
- "botocore*"
- "azure-*"
- "bce-*"
- "cos-python-*"
- "esdk-obs-*"
- "google-cloud-storage"
- "opendal"
- "oss2"
- "supabase*"
- "tos*"
vdb:
patterns:
- "alibabacloud*"
- "chromadb"
- "clickhouse-*"
- "clickzetta-*"
- "couchbase"
- "elasticsearch"
- "opensearch-py"
- "oracledb"
- "pgvect*"
- "pymilvus"
- "pymochow"
- "pyobvector"
- "qdrant-client"
- "intersystems-*"
- "tablestore"
- "tcvectordb"
- "tidb-vector"
- "upstash-*"
- "volcengine-*"
- "weaviate-*"
- "xinference-*"
- "mo-vector"
- "mysql-connector-*"
dev:
patterns:
- "coverage"
- "dotenv-linter"
- "faker"
- "lxml-stubs"
- "basedpyright"
- "ruff"
- "pytest*"
- "types-*"
- "boto3-stubs"
- "hypothesis"
- "pandas-stubs"
- "scipy-stubs"
- "import-linter"
- "celery-types"
- "mypy*"
- "pyrefly"
python-packages:
patterns:
- "*"
- package-ecosystem: "uv"
directory: "/api"
open-pull-requests-limit: 10

9
.github/labeler.yml vendored
View File

@ -1,10 +1,3 @@
web:
- changed-files:
- any-glob-to-any-file:
- 'web/**'
- 'packages/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'
- any-glob-to-any-file: 'web/**'

View File

@ -7,7 +7,6 @@
## Summary
<!-- Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. -->
<!-- If this PR was created by an automated agent, add `From <Tool Name>` as the final line of the description. Example: `From Codex`. -->
## Screenshots
@ -18,7 +17,7 @@
## Checklist
- [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
- [ ] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
- [ ] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
- [ ] I've updated the documentation accordingly.
- [ ] I ran `make lint && make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
- [x] I've updated the documentation accordingly.
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && npx lint-staged` (frontend) to appease the lint gods

View File

@ -1,82 +0,0 @@
import { execFileSync } from 'node:child_process'
import fs from 'node:fs'
import path from 'node:path'
const repoRoot = process.cwd()
const baseSha = process.env.BASE_SHA || ''
const headSha = process.env.HEAD_SHA || ''
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
const outputPath = process.env.I18N_CHANGES_OUTPUT_PATH || '/tmp/i18n-changes.json'
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
const readCurrentJson = (fileStem) => {
const filePath = englishPath(fileStem)
if (!fs.existsSync(filePath))
return null
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
}
const readBaseJson = (fileStem) => {
if (!baseSha)
return null
try {
const relativePath = `web/i18n/en-US/${fileStem}.json`
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
return JSON.parse(content)
}
catch {
return null
}
}
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
const changes = {}
for (const fileStem of files) {
const currentJson = readCurrentJson(fileStem)
const beforeJson = readBaseJson(fileStem) || {}
const afterJson = currentJson || {}
const added = {}
const updated = {}
const deleted = []
for (const [key, value] of Object.entries(afterJson)) {
if (!(key in beforeJson)) {
added[key] = value
continue
}
if (!compareJson(beforeJson[key], value)) {
updated[key] = {
before: beforeJson[key],
after: value,
}
}
}
for (const key of Object.keys(beforeJson)) {
if (!(key in afterJson))
deleted.push(key)
}
changes[fileStem] = {
fileDeleted: currentJson === null,
added,
updated,
deleted,
}
}
fs.writeFileSync(
outputPath,
JSON.stringify({
baseSha,
headSha,
files,
changes,
})
)

View File

@ -54,7 +54,7 @@ jobs:
run: uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Upload unit coverage data
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: api-coverage-unit
path: coverage-unit
@ -129,7 +129,7 @@ jobs:
api/tests/test_containers_integration_tests
- name: Upload integration coverage data
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: api-coverage-integration
path: coverage-integration

View File

@ -39,11 +39,9 @@ jobs:
with:
files: |
web/**
packages/**
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.npmrc
.nvmrc
- name: Check api inputs
if: github.event_name != 'merge_group'

View File

@ -65,7 +65,7 @@ jobs:
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Login to Docker Hub
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
with:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}
@ -81,7 +81,7 @@ jobs:
- name: Build Docker image
id: build
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
with:
context: ${{ matrix.build_context }}
file: ${{ matrix.file }}
@ -101,7 +101,7 @@ jobs:
touch "/tmp/digests/${sanitized_digest}"
- name: Upload digest
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: digests-${{ matrix.artifact_context }}-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*
@ -130,7 +130,7 @@ jobs:
merge-multiple: true
- name: Login to Docker Hub
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
with:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}

View File

@ -6,7 +6,12 @@ on:
- "main"
paths:
- api/Dockerfile
- web/docker/**
- web/Dockerfile
- package.json
- pnpm-lock.yaml
- pnpm-workspace.yaml
- .nvmrc
concurrency:
group: docker-build-${{ github.head_ref || github.run_id }}
@ -43,7 +48,7 @@ jobs:
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
- name: Build Docker Image
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
with:
push: false
context: ${{ matrix.context }}

View File

@ -65,11 +65,9 @@ jobs:
- 'docker/volumes/sandbox/conf/**'
web:
- 'web/**'
- 'packages/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'
- '.github/workflows/web-tests.yml'
- '.github/actions/setup-web/**'
@ -79,11 +77,9 @@ jobs:
- 'api/uv.lock'
- 'e2e/**'
- 'web/**'
- 'packages/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'
- 'docker/docker-compose.middleware.yaml'
- 'docker/middleware.env.example'
@ -92,7 +88,6 @@ jobs:
vdb:
- 'api/core/rag/datasource/**'
- 'api/tests/integration_tests/vdb/**'
- 'api/providers/vdb/*/tests/**'
- '.github/workflows/vdb-tests.yml'
- '.github/workflows/expose_service_ports.sh'
- 'docker/.env.example'

View File

@ -21,7 +21,7 @@ jobs:
if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
steps:
- name: Download pyrefly diff artifact
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
@ -49,7 +49,7 @@ jobs:
run: unzip -o pyrefly_diff.zip
- name: Post comment
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |

View File

@ -66,7 +66,7 @@ jobs:
echo ${{ github.event.pull_request.number }} > pr_number.txt
- name: Upload pyrefly diff
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: pyrefly_diff
path: |
@ -75,7 +75,7 @@ jobs:
- name: Comment PR with pyrefly diff
if: ${{ github.event.pull_request.head.repo.full_name == github.repository && steps.line_count_check.outputs.same == 'false' }}
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |

View File

@ -1,118 +0,0 @@
name: Comment with Pyrefly Type Coverage
on:
workflow_run:
workflows:
- Pyrefly Type Coverage
types:
- completed
permissions: {}
jobs:
comment:
name: Comment PR with type coverage
runs-on: ubuntu-latest
permissions:
actions: read
contents: read
issues: write
pull-requests: write
if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
steps:
- name: Checkout default branch (trusted code)
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Setup Python & UV
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
- name: Install dependencies
run: uv sync --project api --dev
- name: Download type coverage artifact
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const fs = require('fs');
const artifacts = await github.rest.actions.listWorkflowRunArtifacts({
owner: context.repo.owner,
repo: context.repo.repo,
run_id: ${{ github.event.workflow_run.id }},
});
const match = artifacts.data.artifacts.find((artifact) =>
artifact.name === 'pyrefly_type_coverage'
);
if (!match) {
throw new Error('pyrefly_type_coverage artifact not found');
}
const download = await github.rest.actions.downloadArtifact({
owner: context.repo.owner,
repo: context.repo.repo,
artifact_id: match.id,
archive_format: 'zip',
});
fs.writeFileSync('pyrefly_type_coverage.zip', Buffer.from(download.data));
- name: Unzip artifact
run: unzip -o pyrefly_type_coverage.zip
- name: Render coverage markdown from structured data
id: render
run: |
comment_body="$(uv run --directory api python libs/pyrefly_type_coverage.py \
--base base_report.json \
< pr_report.json)"
{
echo "### Pyrefly Type Coverage"
echo ""
echo "$comment_body"
} > /tmp/type_coverage_comment.md
- name: Post comment
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const fs = require('fs');
const body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' });
let prNumber = null;
try {
prNumber = parseInt(fs.readFileSync('pr_number.txt', { encoding: 'utf8' }), 10);
} catch (err) {
const prs = context.payload.workflow_run.pull_requests || [];
if (prs.length > 0 && prs[0].number) {
prNumber = prs[0].number;
}
}
if (!prNumber) {
throw new Error('PR number not found in artifact or workflow_run payload');
}
// Update existing comment if one exists, otherwise create new
const { data: comments } = await github.rest.issues.listComments({
issue_number: prNumber,
owner: context.repo.owner,
repo: context.repo.repo,
});
const marker = '### Pyrefly Type Coverage';
const existing = comments.find(c => c.body.startsWith(marker));
if (existing) {
await github.rest.issues.updateComment({
comment_id: existing.id,
owner: context.repo.owner,
repo: context.repo.repo,
body,
});
} else {
await github.rest.issues.createComment({
issue_number: prNumber,
owner: context.repo.owner,
repo: context.repo.repo,
body,
});
}

View File

@ -1,120 +0,0 @@
name: Pyrefly Type Coverage
on:
pull_request:
paths:
- 'api/**/*.py'
permissions:
contents: read
jobs:
pyrefly-type-coverage:
runs-on: ubuntu-latest
permissions:
contents: read
issues: write
pull-requests: write
steps:
- name: Checkout PR branch
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
- name: Setup Python & UV
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
- name: Install dependencies
run: uv sync --project api --dev
- name: Run pyrefly report on PR branch
run: |
uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_pr.tmp && \
mv /tmp/pyrefly_report_pr.tmp /tmp/pyrefly_report_pr.json || \
echo '{}' > /tmp/pyrefly_report_pr.json
- name: Save helper script from base branch
run: |
git show ${{ github.event.pull_request.base.sha }}:api/libs/pyrefly_type_coverage.py > /tmp/pyrefly_type_coverage.py 2>/dev/null \
|| cp api/libs/pyrefly_type_coverage.py /tmp/pyrefly_type_coverage.py
- name: Checkout base branch
run: git checkout ${{ github.base_ref }}
- name: Run pyrefly report on base branch
run: |
uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_base.tmp && \
mv /tmp/pyrefly_report_base.tmp /tmp/pyrefly_report_base.json || \
echo '{}' > /tmp/pyrefly_report_base.json
- name: Generate coverage comparison
id: coverage
run: |
comment_body="$(uv run --directory api python /tmp/pyrefly_type_coverage.py \
--base /tmp/pyrefly_report_base.json \
< /tmp/pyrefly_report_pr.json)"
{
echo "### Pyrefly Type Coverage"
echo ""
echo "$comment_body"
} | tee -a "$GITHUB_STEP_SUMMARY" > /tmp/type_coverage_comment.md
# Save structured data for the fork-PR comment workflow
cp /tmp/pyrefly_report_pr.json pr_report.json
cp /tmp/pyrefly_report_base.json base_report.json
- name: Save PR number
run: |
echo ${{ github.event.pull_request.number }} > pr_number.txt
- name: Upload type coverage artifact
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: pyrefly_type_coverage
path: |
pr_report.json
base_report.json
pr_number.txt
- name: Comment PR with type coverage
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const fs = require('fs');
const marker = '### Pyrefly Type Coverage';
let body;
try {
body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' });
} catch {
body = `${marker}\n\n_Coverage report unavailable._`;
}
const prNumber = context.payload.pull_request.number;
// Update existing comment if one exists, otherwise create new
const { data: comments } = await github.rest.issues.listComments({
issue_number: prNumber,
owner: context.repo.owner,
repo: context.repo.repo,
});
const existing = comments.find(c => c.body.startsWith(marker));
if (existing) {
await github.rest.issues.updateComment({
comment_id: existing.id,
owner: context.repo.owner,
repo: context.repo.repo,
body,
});
} else {
await github.rest.issues.createComment({
issue_number: prNumber,
owner: context.repo.owner,
repo: context.repo.repo,
body,
});
}

View File

@ -23,8 +23,8 @@ jobs:
days-before-issue-stale: 15
days-before-issue-close: 3
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "Closed due to inactivity. If you have any questions, you can reopen it."
stale-pr-message: "Closed due to inactivity. If you have any questions, you can reopen it."
stale-issue-message: "Close due to it's no longer active, if you have any questions, you can reopen it."
stale-pr-message: "Close due to it's no longer active, if you have any questions, you can reopen it."
stale-issue-label: 'no-issue-activity'
stale-pr-label: 'no-pr-activity'
any-of-labels: '🌚 invalid,🙋‍♂️ question,wont-fix,no-issue-activity,no-pr-activity,💪 enhancement,🤔 cant-reproduce,🙏 help wanted'
any-of-labels: 'duplicate,question,invalid,wontfix,no-issue-activity,no-pr-activity,enhancement,cant-reproduce,help-wanted'

View File

@ -77,11 +77,9 @@ jobs:
with:
files: |
web/**
packages/**
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.npmrc
.nvmrc
.github/workflows/style.yml
.github/actions/setup-web/**
@ -151,7 +149,7 @@ jobs:
.editorconfig
- name: Super-linter
uses: super-linter/super-linter/slim@9e863354e3ff62e0727d37183162c4a88873df41 # v8.6.0
uses: super-linter/super-linter/slim@61abc07d755095a68f4987d1c2c3d1d64408f1f9 # v8.5.0
if: steps.changed-files.outputs.any_changed == 'true'
env:
BASH_SEVERITY: warning

View File

@ -9,7 +9,6 @@ on:
- package.json
- pnpm-lock.yaml
- pnpm-workspace.yaml
- .npmrc
concurrency:
group: sdk-tests-${{ github.head_ref || github.run_id }}

View File

@ -68,7 +68,89 @@ jobs:
" web/i18n-config/languages.ts | sed 's/[[:space:]]*$//')
generate_changes_json() {
node .github/scripts/generate-i18n-changes.mjs
node <<'NODE'
const { execFileSync } = require('node:child_process')
const fs = require('node:fs')
const path = require('node:path')
const repoRoot = process.cwd()
const baseSha = process.env.BASE_SHA || ''
const headSha = process.env.HEAD_SHA || ''
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
const readCurrentJson = (fileStem) => {
const filePath = englishPath(fileStem)
if (!fs.existsSync(filePath))
return null
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
}
const readBaseJson = (fileStem) => {
if (!baseSha)
return null
try {
const relativePath = `web/i18n/en-US/${fileStem}.json`
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
return JSON.parse(content)
}
catch (error) {
return null
}
}
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
const changes = {}
for (const fileStem of files) {
const currentJson = readCurrentJson(fileStem)
const beforeJson = readBaseJson(fileStem) || {}
const afterJson = currentJson || {}
const added = {}
const updated = {}
const deleted = []
for (const [key, value] of Object.entries(afterJson)) {
if (!(key in beforeJson)) {
added[key] = value
continue
}
if (!compareJson(beforeJson[key], value)) {
updated[key] = {
before: beforeJson[key],
after: value,
}
}
}
for (const key of Object.keys(beforeJson)) {
if (!(key in afterJson))
deleted.push(key)
}
changes[fileStem] = {
fileDeleted: currentJson === null,
added,
updated,
deleted,
}
}
fs.writeFileSync(
'/tmp/i18n-changes.json',
JSON.stringify({
baseSha,
headSha,
files,
changes,
})
)
NODE
}
if [ "${{ github.event_name }}" = "repository_dispatch" ]; then
@ -158,7 +240,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.context.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@b47fd721da662d48c5680e154ad16a73ed74d2e0 # v1.0.93
uses: anthropics/claude-code-action@88c168b39e7e64da0286d812b6e9fbebb6708185 # v1.0.82
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}
@ -188,7 +270,7 @@ jobs:
Tool rules:
- Use Read for repository files.
- Use Edit for JSON updates.
- Use Bash only for `vp`.
- Use Bash only for `pnpm`.
- Do not use Bash for `git`, `gh`, or branch management.
Required execution plan:
@ -210,7 +292,7 @@ jobs:
- Read the current English JSON file for any file that still exists so wording, placeholders, and surrounding terminology stay accurate.
- If `Structured change set available` is `false`, treat this as a scoped full sync and use the current English files plus scoped checks as the source of truth.
4. Run a scoped pre-check before editing:
- `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
- `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
- Use this command as the source of truth for missing and extra keys inside the current scope.
5. Apply translations.
- For every target language and scoped file:
@ -218,19 +300,19 @@ jobs:
- If the locale file does not exist yet, create it with `Write` and then continue with `Edit` as needed.
- ADD missing keys.
- UPDATE stale translations when the English value changed.
- DELETE removed keys. Prefer `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope.
- DELETE removed keys. Prefer `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope.
- Preserve placeholders exactly: `{{variable}}`, `${variable}`, HTML tags, component tags, and variable names.
- Match the existing terminology and register used by each locale.
- Prefer one Edit per file when stable, but prioritize correctness over batching.
6. Verify only the edited files.
- Run `vp run dify-web#lint:fix --quiet -- <relative edited i18n file paths under web/>`
- Run `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
- Run `pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- <relative edited i18n file paths>`
- Run `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
- If verification fails, fix the remaining problems before continuing.
7. Stop after the scoped locale files are updated and verification passes.
- Do not create branches, commits, or pull requests.
claude_args: |
--max-turns 120
--allowedTools "Read,Write,Edit,Bash(vp *),Bash(vp:*),Glob,Grep"
--allowedTools "Read,Write,Edit,Bash(pnpm *),Bash(pnpm:*),Glob,Grep"
- name: Prepare branch metadata
id: pr_meta
@ -272,7 +354,6 @@ jobs:
- name: Create or update translation PR
if: steps.pr_meta.outputs.has_changes == 'true'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
BRANCH_NAME: ${{ steps.pr_meta.outputs.branch_name }}
FILES_IN_SCOPE: ${{ steps.context.outputs.CHANGED_FILES }}
TARGET_LANGS: ${{ steps.context.outputs.TARGET_LANGS }}
@ -321,8 +402,8 @@ jobs:
'',
'## Verification',
'',
`- \`vp run dify-web#i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``,
`- \`vp run dify-web#lint:fix --quiet -- <edited i18n files under web/>\``,
`- \`pnpm --dir web run i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``,
`- \`pnpm --dir web lint:fix --quiet -- <edited i18n files>\``,
'',
'## Notes',
'',

View File

@ -42,7 +42,88 @@ jobs:
fi
export BASE_SHA HEAD_SHA CHANGED_FILES
node .github/scripts/generate-i18n-changes.mjs
node <<'NODE'
const { execFileSync } = require('node:child_process')
const fs = require('node:fs')
const path = require('node:path')
const repoRoot = process.cwd()
const baseSha = process.env.BASE_SHA || ''
const headSha = process.env.HEAD_SHA || ''
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
const readCurrentJson = (fileStem) => {
const filePath = englishPath(fileStem)
if (!fs.existsSync(filePath))
return null
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
}
const readBaseJson = (fileStem) => {
if (!baseSha)
return null
try {
const relativePath = `web/i18n/en-US/${fileStem}.json`
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
return JSON.parse(content)
}
catch (error) {
return null
}
}
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
const changes = {}
for (const fileStem of files) {
const beforeJson = readBaseJson(fileStem) || {}
const afterJson = readCurrentJson(fileStem) || {}
const added = {}
const updated = {}
const deleted = []
for (const [key, value] of Object.entries(afterJson)) {
if (!(key in beforeJson)) {
added[key] = value
continue
}
if (!compareJson(beforeJson[key], value)) {
updated[key] = {
before: beforeJson[key],
after: value,
}
}
}
for (const key of Object.keys(beforeJson)) {
if (!(key in afterJson))
deleted.push(key)
}
changes[fileStem] = {
fileDeleted: readCurrentJson(fileStem) === null,
added,
updated,
deleted,
}
}
fs.writeFileSync(
'/tmp/i18n-changes.json',
JSON.stringify({
baseSha,
headSha,
files,
changes,
})
)
NODE
if [ -n "$CHANGED_FILES" ]; then
echo "has_changes=true" >> "$GITHUB_OUTPUT"
@ -56,7 +137,7 @@ jobs:
- name: Trigger i18n sync workflow
if: steps.detect.outputs.has_changes == 'true'
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
env:
BASE_SHA: ${{ steps.detect.outputs.base_sha }}
HEAD_SHA: ${{ steps.detect.outputs.head_sha }}

View File

@ -36,7 +36,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -89,7 +89,7 @@ jobs:
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
# - name: Check VDB Ready (TiDB)
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
- name: Test Vector Stores
run: uv run --project api bash dev/pytest/pytest_vdb.sh

View File

@ -81,12 +81,12 @@ jobs:
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
# - name: Check VDB Ready (TiDB)
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
- name: Test Vector Stores
run: |
uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
api/providers/vdb/vdb-chroma/tests/integration_tests \
api/providers/vdb/vdb-pgvector/tests/integration_tests \
api/providers/vdb/vdb-qdrant/tests/integration_tests \
api/providers/vdb/vdb-weaviate/tests/integration_tests
api/tests/integration_tests/vdb/chroma \
api/tests/integration_tests/vdb/pgvector \
api/tests/integration_tests/vdb/qdrant \
api/tests/integration_tests/vdb/weaviate

View File

@ -53,7 +53,7 @@ jobs:
- name: Upload Cucumber report
if: ${{ !cancelled() }}
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: cucumber-report
path: e2e/cucumber-report
@ -61,7 +61,7 @@ jobs:
- name: Upload E2E logs
if: ${{ !cancelled() }}
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: e2e-logs
path: e2e/.logs

View File

@ -43,7 +43,7 @@ jobs:
- name: Upload blob report
if: ${{ !cancelled() }}
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: blob-report-${{ matrix.shardIndex }}
path: web/.vitest-reports/*

2
.gitignore vendored
View File

@ -212,7 +212,7 @@ api/.vscode
# pnpm
/.pnpm-store
node_modules
/node_modules
.vite-hooks/_
# plugin migrate

1
.npmrc
View File

@ -1 +0,0 @@
save-exact=true

View File

@ -81,18 +81,38 @@ if $web_modified; then
if $web_ts_modified; then
echo "Running TypeScript type-check:tsgo"
if ! npm run type-check:tsgo; then
echo "Type check failed. Please run 'npm run type-check:tsgo' to fix the errors."
if ! pnpm run type-check:tsgo; then
echo "Type check failed. Please run 'pnpm run type-check:tsgo' to fix the errors."
exit 1
fi
else
echo "No staged TypeScript changes detected, skipping type-check:tsgo"
fi
echo "Running knip"
if ! npm run knip; then
echo "Knip check failed. Please run 'npm run knip' to fix the errors."
exit 1
echo "Running unit tests check"
modified_files=$(git diff --cached --name-only -- utils | grep -v '\.spec\.ts$' || true)
if [ -n "$modified_files" ]; then
for file in $modified_files; do
test_file="${file%.*}.spec.ts"
echo "Checking for test file: $test_file"
# check if the test file exists
if [ -f "../$test_file" ]; then
echo "Detected changes in $file, running corresponding unit tests..."
pnpm run test "../$test_file"
if [ $? -ne 0 ]; then
echo "Unit tests failed. Please fix the errors before committing."
exit 1
fi
echo "Unit tests for $file passed."
else
echo "Warning: $file does not have a corresponding test file."
fi
done
echo "All unit tests for modified web/utils files have passed."
fi
cd ../

View File

@ -97,3 +97,8 @@ Feel free to reach out if you encounter any issues during the setup process.
## Getting Help
If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
## Automated Agent Contributions
> [!NOTE]
> If you are an automated agent just add 🤖🤖🤖 to the end of the PR title to opt-in.

View File

@ -57,9 +57,6 @@ REDIS_SSL_CERTFILE=
REDIS_SSL_KEYFILE=
# Path to client private key file for SSL authentication
REDIS_DB=0
# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
# Leave empty to preserve current unprefixed behavior.
REDIS_KEY_PREFIX=
# redis Sentinel configuration.
REDIS_USE_SENTINEL=false
@ -74,13 +71,6 @@ REDIS_USE_CLUSTERS=false
REDIS_CLUSTERS=
REDIS_CLUSTERS_PASSWORD=
REDIS_RETRY_RETRIES=3
REDIS_RETRY_BACKOFF_BASE=1.0
REDIS_RETRY_BACKOFF_CAP=10.0
REDIS_SOCKET_TIMEOUT=5.0
REDIS_SOCKET_CONNECT_TIMEOUT=5.0
REDIS_HEALTH_CHECK_INTERVAL=30
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
CELERY_BACKEND=redis
@ -112,7 +102,6 @@ S3_BUCKET_NAME=your-bucket-name
S3_ACCESS_KEY=your-access-key
S3_SECRET_KEY=your-secret-key
S3_REGION=your-region
S3_ADDRESS_STYLE=auto
# Workflow run and Conversation archive storage (S3-compatible)
ARCHIVE_STORAGE_ENABLED=false

View File

@ -69,6 +69,8 @@ ignore = [
"FURB152", # math-constant
"UP007", # non-pep604-annotation
"UP032", # f-string
"UP045", # non-pep604-annotation-optional
"B005", # strip-with-multi-characters
"B006", # mutable-argument-default
"B007", # unused-loop-control-variable
"B026", # star-arg-unpacking-after-keyword-arg
@ -82,6 +84,7 @@ ignore = [
"SIM102", # collapsible-if
"SIM103", # needless-bool
"SIM105", # suppressible-exception
"SIM107", # return-in-try-except-finally
"SIM108", # if-else-block-instead-of-if-exp
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
@ -90,16 +93,35 @@ ignore = [
]
[lint.per-file-ignores]
"__init__.py" = [
"F401", # unused-import
"F811", # redefined-while-unused
]
"configs/*" = [
"N802", # invalid-function-name
]
"graphon/model_runtime/callbacks/base_callback.py" = ["T201"]
"core/workflow/callbacks/workflow_logging_callback.py" = ["T201"]
"libs/gmpy2_pkcs10aep_cipher.py" = [
"N803", # invalid-argument-name
]
"tests/*" = [
"F811", # redefined-while-unused
"T201", # allow print in tests,
"S110", # allow ignoring exceptions in tests code (currently)
]
"controllers/console/explore/trial.py" = ["TID251"]
"controllers/console/human_input_form.py" = ["TID251"]
"controllers/web/human_input_form.py" = ["TID251"]
[lint.pyflakes]
allowed-unused-imports = [
"tests.integration_tests",
"tests.unit_tests",
]
[lint.flake8-tidy-imports]
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
msg = "Use Pydantic payload/query models instead of reqparse."

View File

@ -21,9 +21,8 @@ RUN apt-get update \
# for building gmpy2
libmpfr-dev libmpc-dev
# Install Python dependencies (workspace members under providers/vdb/)
# Install Python dependencies
COPY pyproject.toml uv.lock ./
COPY providers ./providers
RUN uv sync --locked --no-dev
# production stage

View File

@ -1,18 +0,0 @@
# This module provides a lightweight Celery instance for use in Docker health checks.
# Unlike celery_entrypoint.py, this does NOT import app.py and therefore avoids
# initializing all Flask extensions (DB, Redis, storage, blueprints, etc.).
# Using this module keeps the health check fast and low-cost.
from celery import Celery
from configs import dify_config
from extensions.ext_celery import get_celery_broker_transport_options, get_celery_ssl_options
celery = Celery(broker=dify_config.CELERY_BROKER_URL)
broker_transport_options = get_celery_broker_transport_options()
if broker_transport_options:
celery.conf.update(broker_transport_options=broker_transport_options)
ssl_options = get_celery_ssl_options()
if ssl_options:
celery.conf.update(broker_use_ssl=ssl_options)

View File

@ -2,6 +2,7 @@ import base64
import secrets
import click
from sqlalchemy.orm import sessionmaker
from constants.languages import languages
from extensions.ext_database import db
@ -24,31 +25,30 @@ def reset_password(email, new_password, password_confirm):
return
normalized_email = email.strip().lower()
account = AccountService.get_account_by_email_with_case_fallback(email.strip())
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
try:
valid_password(new_password)
except:
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
return
try:
valid_password(new_password)
except:
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
return
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
# encrypt password with salt
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account = db.session.merge(account)
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
AccountService.reset_login_error_rate_limit(normalized_email)
click.echo(click.style("Password reset successfully.", fg="green"))
# encrypt password with salt
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
AccountService.reset_login_error_rate_limit(normalized_email)
click.echo(click.style("Password reset successfully.", fg="green"))
@click.command("reset-email", help="Reset the account email.")
@ -65,22 +65,21 @@ def reset_email(email, new_email, email_confirm):
return
normalized_new_email = new_email.strip().lower()
account = AccountService.get_account_by_email_with_case_fallback(email.strip())
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
try:
email_validate(normalized_new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
try:
email_validate(normalized_new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
account = db.session.merge(account)
account.email = normalized_new_email
db.session.commit()
click.echo(click.style("Email updated successfully.", fg="green"))
account.email = normalized_new_email
click.echo(click.style("Email updated successfully.", fg="green"))
@click.command("create-tenant", help="Create account and tenant.")

View File

@ -1,7 +1,7 @@
import datetime
import logging
import time
from typing import TypedDict
from typing import Any
import click
import sqlalchemy as sa
@ -503,19 +503,7 @@ def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
return [row[0] for row in result]
class _AppOrphanCounts(TypedDict):
variables: int
files: int
class OrphanedDraftVariableStatsDict(TypedDict):
total_orphaned_variables: int
total_orphaned_files: int
orphaned_app_count: int
orphaned_by_app: dict[str, _AppOrphanCounts]
def _count_orphaned_draft_variables() -> OrphanedDraftVariableStatsDict:
def _count_orphaned_draft_variables() -> dict[str, Any]:
"""
Count orphaned draft variables by app, including associated file counts.
@ -538,7 +526,7 @@ def _count_orphaned_draft_variables() -> OrphanedDraftVariableStatsDict:
with db.engine.connect() as conn:
result = conn.execute(sa.text(variables_query))
orphaned_by_app: dict[str, _AppOrphanCounts] = {}
orphaned_by_app = {}
total_files = 0
for row in result:

View File

@ -341,10 +341,11 @@ def add_qdrant_index(field: str):
click.echo(click.style("No dataset collection bindings found.", fg="red"))
return
import qdrant_client
from dify_vdb_qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import PayloadSchemaType
from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
for binding in bindings:
if dify_config.QDRANT_URL is None:
raise ValueError("Qdrant URL is required.")

View File

@ -1,5 +1,5 @@
import os
from typing import Any, Literal, TypedDict
from typing import Any, Literal
from urllib.parse import parse_qsl, quote_plus
from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
@ -107,17 +107,6 @@ class KeywordStoreConfig(BaseSettings):
)
class SQLAlchemyEngineOptionsDict(TypedDict):
pool_size: int
max_overflow: int
pool_recycle: int
pool_pre_ping: bool
connect_args: dict[str, str]
pool_use_lifo: bool
pool_reset_on_return: None
pool_timeout: int
class DatabaseConfig(BaseSettings):
# Database type selector
DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb"] = Field(
@ -160,16 +149,6 @@ class DatabaseConfig(BaseSettings):
default="",
)
DB_SESSION_TIMEZONE_OVERRIDE: str = Field(
description=(
"PostgreSQL session timezone override injected via startup options."
" Default is 'UTC' for out-of-the-box consistency."
" Set to empty string to disable app-level timezone injection, for example when using RDS Proxy"
" together with a database-side default timezone."
),
default="UTC",
)
@computed_field # type: ignore[prop-decorator]
@property
def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str:
@ -230,22 +209,21 @@ class DatabaseConfig(BaseSettings):
@computed_field # type: ignore[prop-decorator]
@property
def SQLALCHEMY_ENGINE_OPTIONS(self) -> SQLAlchemyEngineOptionsDict:
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
# Parse DB_EXTRAS for 'options'
db_extras_dict = dict(parse_qsl(self.DB_EXTRAS))
options = db_extras_dict.get("options", "")
connect_args: dict[str, str] = {}
connect_args = {}
# Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property
if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"):
merged_options = options.strip()
session_timezone_override = self.DB_SESSION_TIMEZONE_OVERRIDE.strip()
if session_timezone_override:
timezone_opt = f"-c timezone={session_timezone_override}"
merged_options = f"{merged_options} {timezone_opt}".strip() if merged_options else timezone_opt
if merged_options:
connect_args = {"options": merged_options}
timezone_opt = "-c timezone=UTC"
if options:
merged_options = f"{options} {timezone_opt}"
else:
merged_options = timezone_opt
connect_args = {"options": merged_options}
result: SQLAlchemyEngineOptionsDict = {
return {
"pool_size": self.SQLALCHEMY_POOL_SIZE,
"max_overflow": self.SQLALCHEMY_MAX_OVERFLOW,
"pool_recycle": self.SQLALCHEMY_POOL_RECYCLE,
@ -255,7 +233,6 @@ class DatabaseConfig(BaseSettings):
"pool_reset_on_return": None,
"pool_timeout": self.SQLALCHEMY_POOL_TIMEOUT,
}
return result
class CeleryConfig(DatabaseConfig):

View File

@ -32,11 +32,6 @@ class RedisConfig(BaseSettings):
default=0,
)
REDIS_KEY_PREFIX: str = Field(
description="Optional global prefix for Redis keys, topics, and transport artifacts",
default="",
)
REDIS_USE_SSL: bool = Field(
description="Enable SSL/TLS for the Redis connection",
default=False,
@ -122,37 +117,6 @@ class RedisConfig(BaseSettings):
default=None,
)
REDIS_RETRY_RETRIES: NonNegativeInt = Field(
description="Maximum number of retries per Redis command on "
"transient failures (ConnectionError, TimeoutError, socket.timeout)",
default=3,
)
REDIS_RETRY_BACKOFF_BASE: PositiveFloat = Field(
description="Base delay in seconds for exponential backoff between retries",
default=1.0,
)
REDIS_RETRY_BACKOFF_CAP: PositiveFloat = Field(
description="Maximum backoff delay in seconds between retries",
default=10.0,
)
REDIS_SOCKET_TIMEOUT: PositiveFloat | None = Field(
description="Socket timeout in seconds for Redis read/write operations",
default=5.0,
)
REDIS_SOCKET_CONNECT_TIMEOUT: PositiveFloat | None = Field(
description="Socket timeout in seconds for Redis connection establishment",
default=5.0,
)
REDIS_HEALTH_CHECK_INTERVAL: NonNegativeInt = Field(
description="Interval in seconds between Redis connection health checks (0 to disable)",
default=30,
)
@field_validator("REDIS_MAX_CONNECTIONS", mode="before")
@classmethod
def _empty_string_to_none_for_max_conns(cls, v):

View File

@ -1,3 +1,4 @@
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
from pydantic import Field
from pydantic_settings import BaseSettings
@ -41,17 +42,17 @@ class HologresConfig(BaseSettings):
default="public",
)
HOLOGRES_TOKENIZER: str = Field(
HOLOGRES_TOKENIZER: TokenizerType = Field(
description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').",
default="jieba",
)
HOLOGRES_DISTANCE_METHOD: str = Field(
HOLOGRES_DISTANCE_METHOD: DistanceType = Field(
description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').",
default="Cosine",
)
HOLOGRES_BASE_QUANTIZATION_TYPE: str = Field(
HOLOGRES_BASE_QUANTIZATION_TYPE: BaseQuantizationType = Field(
description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').",
default="rabitq",
)

View File

@ -1,7 +1,5 @@
"""Configuration for InterSystems IRIS vector database."""
from typing import Any
from pydantic import Field, PositiveInt, model_validator
from pydantic_settings import BaseSettings
@ -66,7 +64,7 @@ class IrisVectorConfig(BaseSettings):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]:
def validate_config(cls, values: dict) -> dict:
"""Validate IRIS configuration values.
Args:

View File

@ -10,7 +10,7 @@ import threading
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Protocol, final, runtime_checkable
from typing import Any, Protocol, TypeVar, final, runtime_checkable
from pydantic import BaseModel
@ -188,6 +188,8 @@ class ExecutionContextBuilder:
_capturer: Callable[[], IExecutionContext] | None = None
_tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {}
T = TypeVar("T", bound=BaseModel)
class ContextProviderNotFoundError(KeyError):
"""Raised when a tenant-scoped context provider is missing."""

View File

@ -1,4 +1,7 @@
from contextvars import ContextVar
from typing import Generic, TypeVar
T = TypeVar("T")
class HiddenValue:
@ -8,7 +11,7 @@ class HiddenValue:
_default = HiddenValue()
class RecyclableContextVar[T]:
class RecyclableContextVar(Generic[T]):
"""
RecyclableContextVar is a wrapper around ContextVar
It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now

View File

@ -1,104 +0,0 @@
from typing import Any, Literal
from uuid import UUID
from pydantic import BaseModel, Field, model_validator
from libs.helper import UUIDStrOrEmpty
# --- Conversation schemas ---
class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
# --- Message schemas ---
class MessageListQuery(BaseModel):
conversation_id: UUIDStrOrEmpty = Field(description="Conversation UUID")
first_id: UUIDStrOrEmpty | None = Field(default=None, description="First message ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = None
content: str | None = None
# --- Saved message schemas ---
class SavedMessageListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUIDStrOrEmpty
# --- Workflow schemas ---
class DefaultBlockConfigQuery(BaseModel):
q: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
# --- Dataset schemas ---
DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
class ChildChunkCreatePayload(BaseModel):
content: str
class ChildChunkUpdatePayload(BaseModel):
content: str
class DocumentBatchDownloadZipPayload(BaseModel):
"""Request payload for bulk downloading documents as a zip archive."""
document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
class MetadataUpdatePayload(BaseModel):
name: str
# --- Audio schemas ---
class TextToAudioPayload(BaseModel):
message_id: str | None = Field(default=None, description="Message ID")
voice: str | None = Field(default=None, description="Voice to use for TTS")
text: str | None = Field(default=None, description="Text to convert to audio")
streaming: bool | None = Field(default=None, description="Enable streaming response")

View File

@ -1,14 +1,14 @@
from __future__ import annotations
from typing import Any
from typing import Any, TypeAlias
from graphon.file import helpers as file_helpers
from pydantic import BaseModel, ConfigDict, computed_field
from models.model import IconType
type JSONValue = str | int | float | bool | None | dict[str, Any] | list[Any]
type JSONObject = dict[str, Any]
JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
JSONObject: TypeAlias = dict[str, Any]
class SystemParameters(BaseModel):

View File

@ -2,7 +2,7 @@ import csv
import io
from collections.abc import Callable
from functools import wraps
from typing import cast
from typing import ParamSpec, TypeVar
from flask import request
from flask_restx import Resource
@ -18,7 +18,10 @@ from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
from services.billing_service import BillingService, LangContentDict
from services.billing_service import BillingService
P = ParamSpec("P")
R = TypeVar("R")
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -69,9 +72,9 @@ console_ns.schema_model(
)
def admin_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def admin_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.")
@ -329,7 +332,7 @@ class UpsertNotificationApi(Resource):
def post(self):
payload = UpsertNotificationPayload.model_validate(console_ns.payload)
result = BillingService.upsert_notification(
contents=[cast(LangContentDict, c.model_dump()) for c in payload.contents],
contents=[c.model_dump() for c in payload.contents],
frequency=payload.frequency,
status=payload.status,
notification_id=payload.notification_id,

View File

@ -1,16 +1,12 @@
from datetime import datetime
import flask_restx
from flask_restx import Resource
from flask_restx import Resource, fields, marshal_with
from flask_restx._http import HTTPStatus
from pydantic import field_validator
from sqlalchemy import delete, func, select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.enums import ApiTokenType
@ -20,31 +16,21 @@ from services.api_token_service import ApiTokenCache
from . import console_ns
from .wraps import account_initialization_required, edit_permission_required, setup_required
api_key_fields = {
"id": fields.String,
"type": fields.String,
"token": fields.String,
"last_used_at": TimestampField,
"created_at": TimestampField,
}
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields)
api_key_list = {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
class ApiKeyItem(ResponseModel):
id: str
type: str
token: str
last_used_at: int | None = None
created_at: int | None = None
@field_validator("last_used_at", "created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class ApiKeyList(ResponseModel):
data: list[ApiKeyItem]
register_schema_models(console_ns, ApiKeyItem, ApiKeyList)
api_key_list_model = console_ns.model(
"ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
)
def _get_resource(resource_id, tenant_id, resource_model):
@ -68,6 +54,7 @@ class BaseApiKeyListResource(Resource):
token_prefix: str | None = None
max_keys = 10
@marshal_with(api_key_list_model)
def get(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
@ -79,8 +66,9 @@ class BaseApiKeyListResource(Resource):
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
)
).all()
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
return {"items": keys}
@marshal_with(api_key_item_model)
@edit_permission_required
def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
@ -112,7 +100,7 @@ class BaseApiKeyListResource(Resource):
api_token.type = self.resource_type
db.session.add(api_token)
db.session.commit()
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201
return api_token, 201
class BaseApiKeyResource(Resource):
@ -159,7 +147,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("get_app_api_keys")
@console_ns.doc(description="Get all API keys for an app")
@console_ns.doc(params={"resource_id": "App ID"})
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
@console_ns.response(200, "Success", api_key_list_model)
def get(self, resource_id): # type: ignore
"""Get all API keys for an app"""
return super().get(resource_id)
@ -167,7 +155,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("create_app_api_key")
@console_ns.doc(description="Create a new API key for an app")
@console_ns.doc(params={"resource_id": "App ID"})
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(201, "API key created successfully", api_key_item_model)
@console_ns.response(400, "Maximum keys exceeded")
def post(self, resource_id): # type: ignore
"""Create a new API key for an app"""
@ -199,7 +187,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("get_dataset_api_keys")
@console_ns.doc(description="Get all API keys for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
@console_ns.response(200, "Success", api_key_list_model)
def get(self, resource_id): # type: ignore
"""Get all API keys for a dataset"""
return super().get(resource_id)
@ -207,7 +195,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("create_dataset_api_key")
@console_ns.doc(description="Create a new API key for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(201, "API key created successfully", api_key_item_model)
@console_ns.response(400, "Maximum keys exceeded")
def post(self, resource_id): # type: ignore
"""Create a new API key for a dataset"""

View File

@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateArgs, AdvancedPromptTemplateService
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
class AdvancedPromptTemplateQuery(BaseModel):
@ -35,10 +35,5 @@ class AdvancedPromptTemplateList(Resource):
@account_initialization_required
def get(self):
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
prompt_args: AdvancedPromptTemplateArgs = {
"app_mode": args.app_mode,
"model_mode": args.model_mode,
"model_name": args.model_name,
"has_context": args.has_context,
}
return AdvancedPromptTemplateService.get_prompt(prompt_args)
return AdvancedPromptTemplateService.get_prompt(args.model_dump())

View File

@ -25,13 +25,7 @@ from fields.annotation_fields import (
)
from libs.helper import uuid_value
from libs.login import login_required
from services.annotation_service import (
AppAnnotationService,
EnableAnnotationArgs,
UpdateAnnotationArgs,
UpdateAnnotationSettingArgs,
UpsertAnnotationArgs,
)
from services.annotation_service import AppAnnotationService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -126,12 +120,7 @@ class AnnotationReplyActionApi(Resource):
args = AnnotationReplyPayload.model_validate(console_ns.payload)
match action:
case "enable":
enable_args: EnableAnnotationArgs = {
"score_threshold": args.score_threshold,
"embedding_provider_name": args.embedding_provider_name,
"embedding_model_name": args.embedding_model_name,
}
result = AppAnnotationService.enable_app_annotation(enable_args, app_id)
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
case "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
return result, 200
@ -172,8 +161,7 @@ class AppAnnotationSettingUpdateApi(Resource):
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold}
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, setting_args)
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
return result, 200
@ -249,16 +237,8 @@ class AnnotationApi(Resource):
def post(self, app_id):
app_id = str(app_id)
args = CreateAnnotationPayload.model_validate(console_ns.payload)
upsert_args: UpsertAnnotationArgs = {}
if args.answer is not None:
upsert_args["answer"] = args.answer
if args.content is not None:
upsert_args["content"] = args.content
if args.message_id is not None:
upsert_args["message_id"] = args.message_id
if args.question is not None:
upsert_args["question"] = args.question
annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, app_id)
data = args.model_dump(exclude_none=True)
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required
@ -335,12 +315,9 @@ class AnnotationUpdateDeleteApi(Resource):
app_id = str(app_id)
annotation_id = str(annotation_id)
args = UpdateAnnotationPayload.model_validate(console_ns.payload)
update_args: UpdateAnnotationArgs = {}
if args.answer is not None:
update_args["answer"] = args.answer
if args.question is not None:
update_args["question"] = args.question
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_id, annotation_id)
annotation = AppAnnotationService.update_app_annotation_directly(
args.model_dump(exclude_none=True), app_id, annotation_id
)
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required

View File

@ -1,14 +1,15 @@
import logging
import uuid
from datetime import datetime
from typing import Any, Literal
from typing import Any, Literal, TypeAlias
from flask import request
from flask_restx import Resource
from graphon.enums import WorkflowExecutionStatus
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
from graphon.file import helpers as file_helpers
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest
from controllers.common.helpers import FileInfo
@ -25,26 +26,25 @@ from controllers.console.wraps import (
setup_required,
)
from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.entities import PreProcessingRule, Rule, Segmentation
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.trigger.constants import TRIGGER_NODE_TYPES
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import build_icon_url
from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow
from models.model import IconType
from services.app_dsl_service import AppDslService
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.entities.dsl_entities import ImportMode, ImportStatus
from services.entities.knowledge_entities.knowledge_entities import (
DataSource,
InfoList,
NotionIcon,
NotionInfo,
NotionPage,
PreProcessingRule,
RerankingModel,
Rule,
Segmentation,
WebsiteInfo,
WeightKeywordSetting,
WeightModel,
@ -152,7 +152,17 @@ class AppTracePayload(BaseModel):
return value
type JSONValue = Any
JSONValue: TypeAlias = Any
class ResponseModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="ignore",
populate_by_name=True,
serialize_by_alias=True,
protected_namespaces=(),
)
def _to_timestamp(value: datetime | int | None) -> int | None:
@ -161,6 +171,15 @@ def _to_timestamp(value: datetime | int | None) -> int | None:
return value
def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
if icon is None or icon_type is None:
return None
icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
if icon_type_value.lower() != IconType.IMAGE:
return None
return file_helpers.get_signed_file_url(icon)
class Tag(ResponseModel):
id: str
name: str
@ -283,7 +302,7 @@ class Site(ResponseModel):
@computed_field(return_type=str | None) # type: ignore
@property
def icon_url(self) -> str | None:
return build_icon_url(self.icon_type, self.icon)
return _build_icon_url(self.icon_type, self.icon)
@field_validator("icon_type", mode="before")
@classmethod
@ -333,7 +352,7 @@ class AppPartial(ResponseModel):
@computed_field(return_type=str | None) # type: ignore
@property
def icon_url(self) -> str | None:
return build_icon_url(self.icon_type, self.icon)
return _build_icon_url(self.icon_type, self.icon)
@field_validator("created_at", "updated_at", mode="before")
@classmethod
@ -381,7 +400,7 @@ class AppDetailWithSite(AppDetail):
@computed_field(return_type=str | None) # type: ignore
@property
def icon_url(self) -> str | None:
return build_icon_url(self.icon_type, self.icon)
return _build_icon_url(self.icon_type, self.icon)
class AppPagination(ResponseModel):
@ -623,7 +642,7 @@ class AppCopyApi(Resource):
args = CopyAppPayload.model_validate(console_ns.payload or {})
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
import_service = AppDslService(session)
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
result = import_service.import_app(
@ -636,13 +655,6 @@ class AppCopyApi(Resource):
icon=args.icon,
icon_background=args.icon_background,
)
if result.status == ImportStatus.FAILED:
session.rollback()
return result.model_dump(mode="json"), 400
if result.status == ImportStatus.PENDING:
session.rollback()
return result.model_dump(mode="json"), 202
session.commit()
# Inherit web app permission from original app
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:

View File

@ -1,8 +1,7 @@
from flask_restx import Resource
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
@ -11,15 +10,34 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.app_fields import (
app_import_check_dependencies_fields,
app_import_fields,
leaked_dependency_fields,
)
from libs.login import current_account_with_tenant, login_required
from models.model import App
from services.app_dsl_service import AppDslService, Import
from services.app_dsl_service import AppDslService, ImportStatus
from services.enterprise.enterprise_service import EnterpriseService
from services.entities.dsl_entities import CheckDependenciesResult, ImportStatus
from services.feature_service import FeatureService
from .. import console_ns
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base model first
leaked_dependency_model = console_ns.model("LeakedDependency", leaked_dependency_fields)
app_import_model = console_ns.model("AppImport", app_import_fields)
# For nested models, need to replace nested dict with registered model
app_import_check_dependencies_fields_copy = app_import_check_dependencies_fields.copy()
app_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(fields.Nested(leaked_dependency_model))
app_import_check_dependencies_model = console_ns.model(
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppImportPayload(BaseModel):
mode: str = Field(..., description="Import mode")
@ -33,18 +51,18 @@ class AppImportPayload(BaseModel):
app_id: str | None = Field(None)
register_schema_models(console_ns, AppImportPayload, Import, CheckDependenciesResult)
console_ns.schema_model(
AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/apps/imports")
class AppImportApi(Resource):
@console_ns.expect(console_ns.models[AppImportPayload.__name__])
@console_ns.response(200, "Import completed", console_ns.models[Import.__name__])
@console_ns.response(202, "Import pending confirmation", console_ns.models[Import.__name__])
@console_ns.response(400, "Import failed", console_ns.models[Import.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_import_model)
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
@ -52,9 +70,8 @@ class AppImportApi(Resource):
current_user, _ = current_account_with_tenant()
args = AppImportPayload.model_validate(console_ns.payload)
# AppDslService performs internal commits for some creation paths, so use a plain
# Session here instead of nesting it inside sessionmaker(...).begin().
with Session(db.engine, expire_on_commit=False) as session:
# Create service with session
with sessionmaker(db.engine).begin() as session:
import_service = AppDslService(session)
# Import app
account = current_user
@ -70,45 +87,35 @@ class AppImportApi(Resource):
icon_background=args.icon_background,
app_id=args.app_id,
)
if result.status == ImportStatus.FAILED:
session.rollback()
else:
session.commit()
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
# update web app setting as private
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
# Return appropriate status code based on result
status = result.status
match status:
case ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
case ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS:
return result.model_dump(mode="json"), 200
if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
@console_ns.route("/apps/imports/<string:import_id>/confirm")
class AppImportConfirmApi(Resource):
@console_ns.response(200, "Import confirmed", console_ns.models[Import.__name__])
@console_ns.response(400, "Import failed", console_ns.models[Import.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_import_model)
@edit_permission_required
def post(self, import_id):
# Check user role first
current_user, _ = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
# Create service with session
with sessionmaker(db.engine).begin() as session:
import_service = AppDslService(session)
# Confirm import
account = current_user
result = import_service.confirm_import(import_id=import_id, account=account)
if result.status == ImportStatus.FAILED:
session.rollback()
else:
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED:
@ -118,14 +125,14 @@ class AppImportConfirmApi(Resource):
@console_ns.route("/apps/imports/<string:app_id>/check-dependencies")
class AppImportCheckDependenciesApi(Resource):
@console_ns.response(200, "Dependencies checked", console_ns.models[CheckDependenciesResult.__name__])
@setup_required
@login_required
@get_app_model
@account_initialization_required
@marshal_with(app_import_check_dependencies_model)
@edit_permission_required
def get(self, app_model: App):
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(db.engine).begin() as session:
import_service = AppDslService(session)
result = import_service.check_dependencies(app_model=app_model)

View File

@ -1,86 +1,44 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields._value_type_serializer import serialize_value_type
from fields.base import ResponseModel
from fields.conversation_variable_fields import (
conversation_variable_fields,
paginated_conversation_variable_fields,
)
from libs.login import login_required
from models import ConversationVariable
from models.model import AppMode
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ConversationVariablesQuery(BaseModel):
conversation_id: str = Field(..., description="Conversation ID to filter variables")
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
console_ns.schema_model(
ConversationVariablesQuery.__name__,
ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base model first
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
class ConversationVariableResponse(ResponseModel):
id: str
name: str
value_type: str
value: str | None = None
description: str | None = None
created_at: int | None = None
updated_at: int | None = None
@field_validator("value_type", mode="before")
@classmethod
def _normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
return str(exposed_type().value)
if isinstance(value, str):
return value
try:
return serialize_value_type(value)
except Exception:
return serialize_value_type({"value_type": value})
@field_validator("value", mode="before")
@classmethod
def _normalize_value(cls, value: Any | None) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(value)
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class PaginatedConversationVariableResponse(ResponseModel):
page: int
limit: int
total: int
has_more: bool
data: list[ConversationVariableResponse]
register_schema_models(
console_ns,
ConversationVariablesQuery,
ConversationVariableResponse,
PaginatedConversationVariableResponse,
# For nested models, need to replace nested dict with registered model
paginated_conversation_variable_fields_copy = paginated_conversation_variable_fields.copy()
paginated_conversation_variable_fields_copy["data"] = fields.List(
fields.Nested(conversation_variable_model), attribute="data"
)
paginated_conversation_variable_model = console_ns.model(
"PaginatedConversationVariable", paginated_conversation_variable_fields_copy
)
@ -90,15 +48,12 @@ class ConversationVariablesApi(Resource):
@console_ns.doc(description="Get conversation variables for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
@console_ns.response(
200,
"Conversation variables retrieved successfully",
console_ns.models[PaginatedConversationVariableResponse.__name__],
)
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.ADVANCED_CHAT)
@marshal_with(paginated_conversation_variable_model)
def get(self, app_model):
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
@ -117,22 +72,17 @@ class ConversationVariablesApi(Resource):
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
rows = session.scalars(stmt).all()
response = PaginatedConversationVariableResponse.model_validate(
{
"page": page,
"limit": page_size,
"total": len(rows),
"has_more": False,
"data": [
ConversationVariableResponse.model_validate(
{
"created_at": row.created_at,
"updated_at": row.updated_at,
**row.to_variable().model_dump(),
}
)
for row in rows
],
}
)
return response.model_dump(mode="json")
return {
"page": page,
"limit": page_size,
"total": len(rows),
"has_more": False,
"data": [
{
"created_at": row.created_at,
"updated_at": row.updated_at,
**row.to_variable().model_dump(),
}
for row in rows
],
}

View File

@ -1,68 +1,39 @@
import json
from datetime import datetime
from typing import Any
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.app_fields import app_server_fields
from libs.login import current_account_with_tenant, login_required
from models.enums import AppMCPServerStatus
from models.model import AppMCPServer
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
# Register model for flask_restx to avoid dict type issues in Swagger
app_server_model = console_ns.model("AppServer", app_server_fields)
class MCPServerCreatePayload(BaseModel):
description: str | None = Field(default=None, description="Server description")
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
parameters: dict = Field(..., description="Server parameters configuration")
class MCPServerUpdatePayload(BaseModel):
id: str = Field(..., description="Server ID")
description: str | None = Field(default=None, description="Server description")
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
parameters: dict = Field(..., description="Server parameters configuration")
status: str | None = Field(default=None, description="Server status")
class AppMCPServerResponse(ResponseModel):
id: str
name: str
server_code: str
description: str
status: str
parameters: dict[str, Any] | list[Any] | str
created_at: int | None = None
updated_at: int | None = None
@field_validator("parameters", mode="before")
@classmethod
def _parse_json_string(cls, value: Any) -> Any:
if isinstance(value, str):
try:
return json.loads(value)
except (json.JSONDecodeError, TypeError):
return value
return value
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
register_schema_models(console_ns, MCPServerCreatePayload, MCPServerUpdatePayload, AppMCPServerResponse)
for model in (MCPServerCreatePayload, MCPServerUpdatePayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/apps/<uuid:app_id>/server")
@ -70,27 +41,27 @@ class AppMCPServerController(Resource):
@console_ns.doc("get_app_mcp_server")
@console_ns.doc(description="Get MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Server configuration", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(200, "MCP server configuration retrieved successfully", app_server_model)
@login_required
@account_initialization_required
@setup_required
@get_app_model
@marshal_with(app_server_model)
def get(self, app_model):
server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
if server is None:
return {}
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
return server
@console_ns.doc("create_app_mcp_server")
@console_ns.doc(description="Create MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
@console_ns.response(200, "Server created", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(201, "MCP server configuration created successfully", app_server_model)
@console_ns.response(403, "Insufficient permissions")
@account_initialization_required
@get_app_model
@login_required
@setup_required
@marshal_with(app_server_model)
@edit_permission_required
def post(self, app_model):
_, current_tenant_id = current_account_with_tenant()
@ -111,19 +82,20 @@ class AppMCPServerController(Resource):
)
db.session.add(server)
db.session.commit()
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
return server
@console_ns.doc("update_app_mcp_server")
@console_ns.doc(description="Update MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
@console_ns.response(200, "Server updated", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found")
@get_app_model
@login_required
@setup_required
@account_initialization_required
@marshal_with(app_server_model)
@edit_permission_required
def put(self, app_model):
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
@ -146,7 +118,7 @@ class AppMCPServerController(Resource):
except ValueError:
raise ValueError("Invalid status")
db.session.commit()
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
return server
@console_ns.route("/apps/<uuid:server_id>/server/refresh")
@ -154,12 +126,13 @@ class AppMCPServerRefreshController(Resource):
@console_ns.doc("refresh_app_mcp_server")
@console_ns.doc(description="Refresh MCP server configuration and regenerate server code")
@console_ns.doc(params={"server_id": "Server ID"})
@console_ns.response(200, "Server refreshed", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(200, "MCP server refreshed successfully", app_server_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found")
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_server_model)
@edit_permission_required
def get(self, server_id):
_, current_tenant_id = current_account_with_tenant()
@ -172,4 +145,4 @@ class AppMCPServerRefreshController(Resource):
raise NotFound()
server.server_code = AppMCPServer.generate_server_code(16)
db.session.commit()
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
return server

View File

@ -1,15 +1,13 @@
import logging
from datetime import datetime
from typing import Literal
from flask import request
from flask_restx import Resource
from flask_restx import Resource, fields, marshal_with
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, func, select
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.controller_schemas import MessageFeedbackPayload as _MessageFeedbackPayloadBase
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
@ -26,21 +24,10 @@ from controllers.console.wraps import (
setup_required,
)
from core.app.entities.app_invoke_entities import InvokeFrom
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.conversation_fields import (
AgentThought,
ConversationAnnotation,
ConversationAnnotationHitHistory,
Feedback,
JSONValue,
MessageFile,
format_files_contained,
to_timestamp,
)
from libs.helper import uuid_value
from fields.raws import FilesContainedField
from libs.helper import TimestampField, uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_account_with_tenant, login_required
from models.enums import FeedbackFromSource, FeedbackRating
@ -72,8 +59,10 @@ class ChatMessagesQuery(BaseModel):
return uuid_value(value)
class MessageFeedbackPayload(_MessageFeedbackPayloadBase):
class MessageFeedbackPayload(BaseModel):
message_id: str = Field(..., description="Message ID")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
@field_validator("message_id")
@classmethod
@ -110,51 +99,6 @@ class SuggestedQuestionsResponse(BaseModel):
data: list[str] = Field(description="Suggested question")
class MessageDetailResponse(ResponseModel):
id: str
conversation_id: str
inputs: dict[str, JSONValue]
query: str
message: JSONValue | None = None
message_tokens: int | None = None
answer: str = Field(validation_alias="re_sign_file_url_answer")
answer_tokens: int | None = None
provider_response_latency: float | None = None
from_source: str
from_end_user_id: str | None = None
from_account_id: str | None = None
feedbacks: list[Feedback] = Field(default_factory=list)
workflow_run_id: str | None = None
annotation: ConversationAnnotation | None = None
annotation_hit_history: ConversationAnnotationHitHistory | None = None
created_at: int | None = None
agent_thoughts: list[AgentThought] = Field(default_factory=list)
message_files: list[MessageFile] = Field(default_factory=list)
extra_contents: list[ExecutionExtraContentDomainModel] = Field(default_factory=list)
metadata: JSONValue | None = Field(default=None, validation_alias="message_metadata_dict")
status: str
error: str | None = None
parent_message_id: str | None = None
@field_validator("inputs", mode="before")
@classmethod
def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
return format_files_contained(value)
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
class MessageInfiniteScrollPaginationResponse(ResponseModel):
limit: int
has_more: bool
data: list[MessageDetailResponse]
register_schema_models(
console_ns,
ChatMessagesQuery,
@ -162,8 +106,124 @@ register_schema_models(
FeedbackExportQuery,
AnnotationCountResponse,
SuggestedQuestionsResponse,
MessageDetailResponse,
MessageInfiniteScrollPaginationResponse,
)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
# Base models
simple_account_model = console_ns.model(
"SimpleAccount",
{
"id": fields.String,
"name": fields.String,
"email": fields.String,
},
)
message_file_model = console_ns.model(
"MessageFile",
{
"id": fields.String,
"filename": fields.String,
"type": fields.String,
"url": fields.String,
"mime_type": fields.String,
"size": fields.Integer,
"transfer_method": fields.String,
"belongs_to": fields.String(default="user"),
"upload_file_id": fields.String(default=None),
},
)
agent_thought_model = console_ns.model(
"AgentThought",
{
"id": fields.String,
"chain_id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"thought": fields.String,
"tool": fields.String,
"tool_labels": fields.Raw,
"tool_input": fields.String,
"created_at": TimestampField,
"observation": fields.String,
"files": fields.List(fields.String),
},
)
# Models that depend on simple_account_model
feedback_model = console_ns.model(
"Feedback",
{
"rating": fields.String,
"content": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account": fields.Nested(simple_account_model, allow_null=True),
},
)
annotation_model = console_ns.model(
"Annotation",
{
"id": fields.String,
"question": fields.String,
"content": fields.String,
"account": fields.Nested(simple_account_model, allow_null=True),
"created_at": TimestampField,
},
)
annotation_hit_history_model = console_ns.model(
"AnnotationHitHistory",
{
"annotation_id": fields.String(attribute="id"),
"annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
"created_at": TimestampField,
},
)
# Message detail model that depends on multiple models
message_detail_model = console_ns.model(
"MessageDetail",
{
"id": fields.String,
"conversation_id": fields.String,
"inputs": FilesContainedField,
"query": fields.String,
"message": fields.Raw,
"message_tokens": fields.Integer,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"answer_tokens": fields.Integer,
"provider_response_latency": fields.Float,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"feedbacks": fields.List(fields.Nested(feedback_model)),
"workflow_run_id": fields.String,
"annotation": fields.Nested(annotation_model, allow_null=True),
"annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
"message_files": fields.List(fields.Nested(message_file_model)),
"extra_contents": fields.List(fields.Raw),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
"parent_message_id": fields.String,
},
)
# Message infinite scroll pagination model
message_infinite_scroll_pagination_model = console_ns.model(
"MessageInfiniteScrollPagination",
{
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_detail_model)),
},
)
@ -173,12 +233,13 @@ class ChatMessageListApi(Resource):
@console_ns.doc(description="Get chat messages for a conversation with pagination")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
@console_ns.response(200, "Success", console_ns.models[MessageInfiniteScrollPaginationResponse.__name__])
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
@console_ns.response(404, "Conversation not found")
@login_required
@account_initialization_required
@setup_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(message_infinite_scroll_pagination_model)
@edit_permission_required
def get(self, app_model):
args = ChatMessagesQuery.model_validate(request.args.to_dict())
@ -238,10 +299,7 @@ class ChatMessageListApi(Resource):
history_messages = list(reversed(history_messages))
attach_message_extra_contents(history_messages)
return MessageInfiniteScrollPaginationResponse.model_validate(
InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more),
from_attributes=True,
).model_dump(mode="json")
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
@ -411,12 +469,13 @@ class MessageApi(Resource):
@console_ns.doc("get_message")
@console_ns.doc(description="Get message details by ID")
@console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
@console_ns.response(200, "Message retrieved successfully", console_ns.models[MessageDetailResponse.__name__])
@console_ns.response(200, "Message retrieved successfully", message_detail_model)
@console_ns.response(404, "Message not found")
@get_app_model
@setup_required
@login_required
@account_initialization_required
@marshal_with(message_detail_model)
def get(self, app_model, message_id: str):
message_id = str(message_id)
@ -428,4 +487,4 @@ class MessageApi(Resource):
raise NotFound("Message Not Exists.")
attach_message_extra_contents([message])
return MessageDetailResponse.model_validate(message, from_attributes=True).model_dump(mode="json")
return message

View File

@ -1,11 +1,9 @@
import json
from typing import Any, cast
from typing import cast
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from flask_restx import Resource, fields
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
@ -20,30 +18,30 @@ from models.model import AppMode, AppModelConfig
from services.app_model_config_service import AppModelConfigService
class ModelConfigRequest(BaseModel):
provider: str | None = Field(default=None, description="Model provider")
model: str | None = Field(default=None, description="Model name")
configs: dict[str, Any] | None = Field(default=None, description="Model configuration parameters")
opening_statement: str | None = Field(default=None, description="Opening statement")
suggested_questions: list[str] | None = Field(default=None, description="Suggested questions")
more_like_this: dict[str, Any] | None = Field(default=None, description="More like this configuration")
speech_to_text: dict[str, Any] | None = Field(default=None, description="Speech to text configuration")
text_to_speech: dict[str, Any] | None = Field(default=None, description="Text to speech configuration")
retrieval_model: dict[str, Any] | None = Field(default=None, description="Retrieval model configuration")
tools: list[dict[str, Any]] | None = Field(default=None, description="Available tools")
dataset_configs: dict[str, Any] | None = Field(default=None, description="Dataset configurations")
agent_mode: dict[str, Any] | None = Field(default=None, description="Agent mode configuration")
register_schema_models(console_ns, ModelConfigRequest)
@console_ns.route("/apps/<uuid:app_id>/model-config")
class ModelConfigResource(Resource):
@console_ns.doc("update_app_model_config")
@console_ns.doc(description="Update application model configuration")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ModelConfigRequest.__name__])
@console_ns.expect(
console_ns.model(
"ModelConfigRequest",
{
"provider": fields.String(description="Model provider"),
"model": fields.String(description="Model name"),
"configs": fields.Raw(description="Model configuration parameters"),
"opening_statement": fields.String(description="Opening statement"),
"suggested_questions": fields.List(fields.String(), description="Suggested questions"),
"more_like_this": fields.Raw(description="More like this configuration"),
"speech_to_text": fields.Raw(description="Speech to text configuration"),
"text_to_speech": fields.Raw(description="Text to speech configuration"),
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
"tools": fields.List(fields.Raw(), description="Available tools"),
"dataset_configs": fields.Raw(description="Dataset configurations"),
"agent_mode": fields.Raw(description="Agent mode configuration"),
},
)
)
@console_ns.response(200, "Model configuration updated successfully")
@console_ns.response(400, "Invalid configuration")
@console_ns.response(404, "App not found")

View File

@ -1,12 +1,11 @@
from typing import Literal
from flask_restx import Resource
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from constants.languages import supported_language
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
@ -16,11 +15,13 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.app_fields import app_site_fields
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import Site
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppSiteUpdatePayload(BaseModel):
title: str | None = Field(default=None)
@ -48,26 +49,13 @@ class AppSiteUpdatePayload(BaseModel):
return supported_language(value)
class AppSiteResponse(ResponseModel):
app_id: str
access_token: str | None = Field(default=None, validation_alias="code")
code: str | None = None
title: str
icon: str | None = None
icon_background: str | None = None
description: str | None = None
default_language: str
customize_domain: str | None = None
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
customize_token_strategy: str
prompt_public: bool
show_workflow_steps: bool
use_icon_as_answer_icon: bool
console_ns.schema_model(
AppSiteUpdatePayload.__name__,
AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
register_schema_models(console_ns, AppSiteUpdatePayload, AppSiteResponse)
# Register model for flask_restx to avoid dict type issues in Swagger
app_site_model = console_ns.model("AppSite", app_site_fields)
@console_ns.route("/apps/<uuid:app_id>/site")
@ -76,7 +64,7 @@ class AppSite(Resource):
@console_ns.doc(description="Update application site configuration")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__])
@console_ns.response(200, "Site configuration updated successfully", console_ns.models[AppSiteResponse.__name__])
@console_ns.response(200, "Site configuration updated successfully", app_site_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "App not found")
@setup_required
@ -84,6 +72,7 @@ class AppSite(Resource):
@edit_permission_required
@account_initialization_required
@get_app_model
@marshal_with(app_site_model)
def post(self, app_model):
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
current_user, _ = current_account_with_tenant()
@ -117,7 +106,7 @@ class AppSite(Resource):
site.updated_at = naive_utc_now()
db.session.commit()
return AppSiteResponse.model_validate(site, from_attributes=True).model_dump(mode="json")
return site
@console_ns.route("/apps/<uuid:app_id>/site/access-token-reset")
@ -125,7 +114,7 @@ class AppSiteAccessTokenReset(Resource):
@console_ns.doc("reset_app_site_access_token")
@console_ns.doc(description="Reset access token for application site")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Access token reset successfully", console_ns.models[AppSiteResponse.__name__])
@console_ns.response(200, "Access token reset successfully", app_site_model)
@console_ns.response(403, "Insufficient permissions (admin/owner required)")
@console_ns.response(404, "App or site not found")
@setup_required
@ -133,6 +122,7 @@ class AppSiteAccessTokenReset(Resource):
@is_admin_or_owner_required
@account_initialization_required
@get_app_model
@marshal_with(app_site_model)
def post(self, app_model):
current_user, _ = current_account_with_tenant()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
@ -145,4 +135,4 @@ class AppSiteAccessTokenReset(Resource):
site.updated_at = naive_utc_now()
db.session.commit()
return AppSiteResponse.model_validate(site, from_attributes=True).model_dump(mode="json")
return site

View File

@ -4,7 +4,7 @@ from collections.abc import Sequence
from typing import Any
from flask import abort, request
from flask_restx import Resource, fields, marshal, marshal_with
from flask_restx import Resource, fields, marshal_with
from graphon.enums import NodeType
from graphon.file import File
from graphon.graph_engine.manager import GraphEngineManager
@ -14,7 +14,6 @@ from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
import services
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
from controllers.console import console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.workflow_run import workflow_run_node_execution_model
@ -143,6 +142,10 @@ class PublishWorkflowPayload(BaseModel):
marked_comment: str | None = Field(default=None, max_length=100)
class DefaultBlockConfigQuery(BaseModel):
q: str | None = None
class ConvertToWorkflowPayload(BaseModel):
name: str | None = None
icon_type: str | None = None
@ -150,6 +153,18 @@ class ConvertToWorkflowPayload(BaseModel):
icon_background: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class DraftWorkflowTriggerRunPayload(BaseModel):
node_id: str
@ -942,6 +957,7 @@ class PublishedAllWorkflowApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_pagination_model)
@edit_permission_required
def get(self, app_model: App):
"""
@ -969,10 +985,9 @@ class PublishedAllWorkflowApi(Resource):
user_id=user_id,
named_only=named_only,
)
serialized_workflows = marshal(workflows, workflow_fields_copy)
return {
"items": serialized_workflows,
"items": workflows,
"page": page,
"limit": limit,
"has_more": has_more,

View File

@ -1,26 +1,27 @@
from datetime import datetime
from typing import Any
from dateutil.parser import isoparse
from flask import request
from flask_restx import Resource
from flask_restx import Resource, marshal_with
from graphon.enums import WorkflowExecutionStatus
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.end_user_fields import SimpleEndUser
from fields.member_fields import SimpleAccount
from fields.workflow_app_log_fields import (
build_workflow_app_log_pagination_model,
build_workflow_archived_log_pagination_model,
)
from libs.login import login_required
from models import App
from models.model import AppMode
from services.workflow_app_service import WorkflowAppService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowAppLogQuery(BaseModel):
keyword: str | None = Field(default=None, description="Search keyword for filtering logs")
@ -57,114 +58,14 @@ class WorkflowAppLogQuery(BaseModel):
raise ValueError("Invalid boolean value for detail")
class WorkflowRunForLogResponse(ResponseModel):
id: str
version: str | None = None
status: str | None = None
triggered_from: str | None = None
error: str | None = None
elapsed_time: float | None = None
total_tokens: int | None = None
total_steps: int | None = None
created_at: int | None = None
finished_at: int | None = None
exceptions_count: int | None = None
@field_validator("status", mode="before")
@classmethod
def _normalize_status(cls, value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(getattr(value, "value", value))
@field_validator("created_at", "finished_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class WorkflowRunForArchivedLogResponse(ResponseModel):
id: str
status: str | None = None
triggered_from: str | None = None
elapsed_time: float | None = None
total_tokens: int | None = None
@field_validator("status", mode="before")
@classmethod
def _normalize_status(cls, value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(getattr(value, "value", value))
class WorkflowAppLogPartialResponse(ResponseModel):
id: str
workflow_run: WorkflowRunForLogResponse | None = None
details: Any = None
created_from: str | None = None
created_by_role: str | None = None
created_by_account: SimpleAccount | None = None
created_by_end_user: SimpleEndUser | None = None
created_at: int | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class WorkflowArchivedLogPartialResponse(ResponseModel):
id: str
workflow_run: WorkflowRunForArchivedLogResponse | None = None
trigger_metadata: Any = None
created_by_account: SimpleAccount | None = None
created_by_end_user: SimpleEndUser | None = None
created_at: int | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class WorkflowAppLogPaginationResponse(ResponseModel):
page: int
limit: int
total: int
has_more: bool
data: list[WorkflowAppLogPartialResponse]
class WorkflowArchivedLogPaginationResponse(ResponseModel):
page: int
limit: int
total: int
has_more: bool
data: list[WorkflowArchivedLogPartialResponse]
register_schema_models(
console_ns,
WorkflowAppLogQuery,
WorkflowRunForLogResponse,
WorkflowRunForArchivedLogResponse,
WorkflowAppLogPartialResponse,
WorkflowArchivedLogPartialResponse,
WorkflowAppLogPaginationResponse,
WorkflowArchivedLogPaginationResponse,
console_ns.schema_model(
WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
# Register model for flask_restx to avoid dict type issues in Swagger
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
workflow_archived_log_pagination_model = build_workflow_archived_log_pagination_model(console_ns)
@console_ns.route("/apps/<uuid:app_id>/workflow-app-logs")
class WorkflowAppLogApi(Resource):
@ -172,15 +73,12 @@ class WorkflowAppLogApi(Resource):
@console_ns.doc(description="Get workflow application execution logs")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
@console_ns.response(
200,
"Workflow app logs retrieved successfully",
console_ns.models[WorkflowAppLogPaginationResponse.__name__],
)
@console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@marshal_with(workflow_app_log_pagination_model)
def get(self, app_model: App):
"""
Get workflow app logs
@ -189,7 +87,7 @@ class WorkflowAppLogApi(Resource):
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with sessionmaker(db.engine).begin() as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,
@ -204,9 +102,7 @@ class WorkflowAppLogApi(Resource):
created_by_account=args.created_by_account,
)
return WorkflowAppLogPaginationResponse.model_validate(
workflow_app_log_pagination, from_attributes=True
).model_dump(mode="json")
return workflow_app_log_pagination
@console_ns.route("/apps/<uuid:app_id>/workflow-archived-logs")
@ -215,15 +111,12 @@ class WorkflowArchivedLogApi(Resource):
@console_ns.doc(description="Get workflow archived execution logs")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
@console_ns.response(
200,
"Workflow archived logs retrieved successfully",
console_ns.models[WorkflowArchivedLogPaginationResponse.__name__],
)
@console_ns.response(200, "Workflow archived logs retrieved successfully", workflow_archived_log_pagination_model)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@marshal_with(workflow_archived_log_pagination_model)
def get(self, app_model: App):
"""
Get workflow archived logs
@ -231,7 +124,7 @@ class WorkflowArchivedLogApi(Resource):
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workflow_app_service = WorkflowAppService()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with sessionmaker(db.engine).begin() as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
session=session,
app_model=app_model,
@ -239,6 +132,4 @@ class WorkflowArchivedLogApi(Resource):
limit=args.limit,
)
return WorkflowArchivedLogPaginationResponse.model_validate(
workflow_app_log_pagination, from_attributes=True
).model_dump(mode="json")
return workflow_app_log_pagination

View File

@ -1,7 +1,7 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import Any, TypedDict
from typing import Any, NoReturn, ParamSpec, TypeVar
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
@ -86,14 +86,7 @@ def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
return value_type.exposed_type().value
class FullContentDict(TypedDict):
size_bytes: int | None
value_type: str
length: int | None
download_url: str
def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict | None:
def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
"""Serialize full_content information for large variables."""
if not variable.is_truncated():
return None
@ -101,13 +94,12 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict
variable_file = variable.variable_file
assert variable_file is not None
result: FullContentDict = {
return {
"size_bytes": variable_file.size,
"value_type": variable_file.value_type.exposed_type().value,
"length": variable_file.length,
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
}
return result
def _ensure_variable_access(
@ -200,8 +192,11 @@ workflow_draft_variable_list_model = console_ns.model(
"WorkflowDraftVariableList", workflow_draft_variable_list_fields_copy
)
P = ParamSpec("P")
R = TypeVar("R")
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
def _api_prerequisite(f: Callable[P, R]):
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
@ -218,7 +213,7 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
@edit_permission_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
def wrapper(*args: P.args, **kwargs: P.kwargs):
return f(*args, **kwargs)
return wrapper
@ -275,7 +270,7 @@ class WorkflowVariableCollectionApi(Resource):
return Response("", 204)
def validate_node_id(node_id: str) -> None:
def validate_node_id(node_id: str) -> NoReturn | None:
if node_id in [
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
@ -290,6 +285,7 @@ def validate_node_id(node_id: str) -> None:
raise InvalidArgumentError(
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
)
return None
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
@ -392,27 +388,24 @@ class VariableApi(Resource):
new_value = None
if raw_value is not None:
match variable.value_type:
case SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
case SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
case _:
pass
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()

View File

@ -36,7 +36,7 @@ from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowR
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
from services.workflow_run_service import WorkflowRunListArgs, WorkflowRunService
from services.workflow_run_service import WorkflowRunService
def _build_backstage_input_url(form_token: str | None) -> str | None:
@ -214,11 +214,7 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
Get advanced chat app workflow run list
"""
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args: WorkflowRunListArgs = {"limit": args_model.limit}
if args_model.last_id is not None:
args["last_id"] = args_model.last_id
if args_model.status is not None:
args["status"] = args_model.status
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING if not specified
triggered_from = (
@ -360,11 +356,7 @@ class WorkflowRunListApi(Resource):
Get workflow run list
"""
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args: WorkflowRunListArgs = {"limit": args_model.limit}
if args_model.last_id is not None:
args["last_id"] = args_model.last_id
if args_model.status is not None:
args["status"] = args_model.status
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (

View File

@ -1,17 +1,16 @@
import logging
from datetime import datetime
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, field_validator
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.common.schema import get_or_create_model
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from libs.login import current_user, login_required
from models.enums import AppTriggerStatus
from models.model import Account, App, AppMode
@ -22,6 +21,15 @@ from ..app.wraps import get_app_model
from ..wraps import account_initialization_required, edit_permission_required, setup_required
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
trigger_model = get_or_create_model("WorkflowTrigger", trigger_fields)
triggers_list_fields_copy = triggers_list_fields.copy()
triggers_list_fields_copy["data"] = fields.List(fields.Nested(trigger_model))
triggers_list_model = get_or_create_model("WorkflowTriggerList", triggers_list_fields_copy)
webhook_trigger_model = get_or_create_model("WebhookTrigger", webhook_trigger_fields)
class Parser(BaseModel):
@ -33,52 +41,10 @@ class ParserEnable(BaseModel):
enable_trigger: bool
class WorkflowTriggerResponse(ResponseModel):
id: str
trigger_type: str
title: str
node_id: str
provider_name: str
icon: str
status: str
created_at: datetime | None = None
updated_at: datetime | None = None
console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@field_validator("id", "trigger_type", "title", "node_id", "provider_name", "icon", "status", mode="before")
@classmethod
def _normalize_string_fields(cls, value: object) -> str:
if isinstance(value, str):
return value
return str(value)
class WorkflowTriggerListResponse(ResponseModel):
data: list[WorkflowTriggerResponse]
class WebhookTriggerResponse(ResponseModel):
id: str
webhook_id: str
webhook_url: str
webhook_debug_url: str
node_id: str
created_at: datetime | None = None
@field_validator("id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", mode="before")
@classmethod
def _normalize_string_fields(cls, value: object) -> str:
if isinstance(value, str):
return value
return str(value)
register_schema_models(
console_ns,
Parser,
ParserEnable,
WorkflowTriggerResponse,
WorkflowTriggerListResponse,
WebhookTriggerResponse,
console_ns.schema_model(
ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@ -91,28 +57,28 @@ class WebhookTriggerApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@console_ns.response(200, "Success", console_ns.models[WebhookTriggerResponse.__name__])
@marshal_with(webhook_trigger_model)
def get(self, app_model: App):
"""Get webhook trigger for a node"""
args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore
node_id = args.node_id
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with sessionmaker(db.engine).begin() as session:
# Get webhook trigger for this app and node
webhook_trigger = session.scalar(
select(WorkflowWebhookTrigger)
webhook_trigger = (
session.query(WorkflowWebhookTrigger)
.where(
WorkflowWebhookTrigger.app_id == app_model.id,
WorkflowWebhookTrigger.node_id == node_id,
)
.limit(1)
.first()
)
if not webhook_trigger:
raise NotFound("Webhook trigger not found for this node")
return WebhookTriggerResponse.model_validate(webhook_trigger, from_attributes=True).model_dump(mode="json")
return webhook_trigger
@console_ns.route("/apps/<uuid:app_id>/triggers")
@ -123,13 +89,13 @@ class AppTriggersApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@console_ns.response(200, "Success", console_ns.models[WorkflowTriggerListResponse.__name__])
@marshal_with(triggers_list_model)
def get(self, app_model: App):
"""Get app triggers list"""
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with sessionmaker(db.engine).begin() as session:
# Get all triggers for this app using select API
triggers = (
session.execute(
@ -152,9 +118,7 @@ class AppTriggersApi(Resource):
else:
trigger.icon = "" # type: ignore
return WorkflowTriggerListResponse.model_validate({"data": triggers}, from_attributes=True).model_dump(
mode="json"
)
return {"data": triggers}
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
@ -165,7 +129,7 @@ class AppTriggerEnableApi(Resource):
@account_initialization_required
@edit_permission_required
@get_app_model(mode=AppMode.WORKFLOW)
@console_ns.response(200, "Success", console_ns.models[WorkflowTriggerResponse.__name__])
@marshal_with(trigger_model)
def post(self, app_model: App):
"""Update app trigger (enable/disable)"""
args = ParserEnable.model_validate(console_ns.payload)
@ -196,4 +160,4 @@ class AppTriggerEnableApi(Resource):
else:
trigger.icon = "" # type: ignore
return WorkflowTriggerResponse.model_validate(trigger, from_attributes=True).model_dump(mode="json")
return trigger

View File

@ -1,6 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import overload
from typing import ParamSpec, TypeVar, Union
from sqlalchemy import select
@ -9,6 +9,11 @@ from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models import App, AppMode
P = ParamSpec("P")
R = TypeVar("R")
P1 = ParamSpec("P1")
R1 = TypeVar("R1")
def _load_app_model(app_id: str) -> App | None:
_, current_tenant_id = current_account_with_tenant()
@ -23,30 +28,10 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
return app_model
@overload
def get_app_model[**P, R](
view: Callable[P, R],
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[P, R]: ...
@overload
def get_app_model[**P, R](
view: None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def get_app_model[**P, R](
view: Callable[P, R] | None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P1, R1]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated_view(*args: P1.args, **kwargs: P1.kwargs):
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")
@ -84,30 +69,10 @@ def get_app_model[**P, R](
return decorator(view)
@overload
def get_app_model_with_trial[**P, R](
view: Callable[P, R],
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[P, R]: ...
@overload
def get_app_model_with_trial[**P, R](
view: None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def get_app_model_with_trial[**P, R](
view: Callable[P, R] | None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated_view(*args: P.args, **kwargs: P.kwargs):
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")

View File

@ -1,11 +1,8 @@
from typing import Any
from flask import request
from flask_restx import Resource
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from constants.languages import supported_language
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
@ -14,6 +11,8 @@ from libs.helper import EmailStr, timezone
from models import AccountStatus
from services.account_service import RegisterService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ActivateCheckQuery(BaseModel):
workspace_id: str | None = Field(default=None)
@ -40,16 +39,8 @@ class ActivatePayload(BaseModel):
return timezone(value)
class ActivationCheckResponse(BaseModel):
is_valid: bool = Field(description="Whether token is valid")
data: dict[str, Any] | None = Field(default=None, description="Activation data if valid")
class ActivationResponse(BaseModel):
result: str = Field(description="Operation result")
register_schema_models(console_ns, ActivateCheckQuery, ActivatePayload, ActivationCheckResponse, ActivationResponse)
for model in (ActivateCheckQuery, ActivatePayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/activate/check")
@ -60,7 +51,13 @@ class ActivateCheckApi(Resource):
@console_ns.response(
200,
"Success",
console_ns.models[ActivationCheckResponse.__name__],
console_ns.model(
"ActivationCheckResponse",
{
"is_valid": fields.Boolean(description="Whether token is valid"),
"data": fields.Raw(description="Activation data if valid"),
},
),
)
def get(self):
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
@ -98,7 +95,12 @@ class ActivateApi(Resource):
@console_ns.response(
200,
"Account activated successfully",
console_ns.models[ActivationResponse.__name__],
console_ns.model(
"ActivationResponse",
{
"result": fields.String(description="Operation result"),
},
),
)
@console_ns.response(400, "Already activated or invalid token")
def post(self):

View File

@ -1,6 +1,7 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from constants.languages import languages
@ -13,6 +14,7 @@ from controllers.console.auth.error import (
InvalidTokenError,
PasswordMismatchError,
)
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import valid_password
from models import Account
@ -71,7 +73,8 @@ class EmailRegisterSendEmailApi(Resource):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError()
account = AccountService.get_account_by_email_with_case_fallback(args.email)
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
return {"result": "success", "data": token}
@ -142,16 +145,17 @@ class EmailRegisterResetApi(Resource):
email = register_data.get("email", "")
normalized_email = email.lower()
account = AccountService.get_account_by_email_with_case_fallback(email)
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
raise EmailAlreadyInUseError()
else:
account = self._create_new_account(normalized_email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(normalized_email)
if account:
raise EmailAlreadyInUseError()
else:
account = self._create_new_account(normalized_email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(normalized_email)
return {"result": "success", "data": token_pair.model_dump()}

View File

@ -3,7 +3,8 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
@ -19,18 +20,35 @@ from controllers.console.wraps import email_password_login_enabled, setup_requir
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password
from libs.password import hash_password, valid_password
from services.account_service import AccountService, TenantService
from services.entities.auth_entities import (
ForgotPasswordCheckPayload,
ForgotPasswordResetPayload,
ForgotPasswordSendPayload,
)
from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ForgotPasswordSendPayload(BaseModel):
email: EmailStr = Field(...)
language: str | None = Field(default=None)
class ForgotPasswordCheckPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
class ForgotPasswordResetPayload(BaseModel):
token: str = Field(...)
new_password: str = Field(...)
password_confirm: str = Field(...)
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
class ForgotPasswordEmailResponse(BaseModel):
result: str = Field(description="Operation result")
data: str | None = Field(default=None, description="Reset token")
@ -84,7 +102,8 @@ class ForgotPasswordSendEmailApi(Resource):
else:
language = "en-US"
account = AccountService.get_account_by_email_with_case_fallback(args.email)
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_reset_password_email(
account=account,
@ -182,18 +201,17 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "")
account = AccountService.get_account_by_email_with_case_fallback(email)
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
account = db.session.merge(account)
self._update_existing_account(account, password_hashed, salt)
db.session.commit()
else:
raise AccountNotFound()
if account:
self._update_existing_account(account, password_hashed, salt, session)
else:
raise AccountNotFound()
return {"result": "success"}
def _update_existing_account(self, account, password_hashed, salt):
def _update_existing_account(self, account, password_hashed, salt, session):
# Update existing account credentials
account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode()

View File

@ -1,10 +1,9 @@
import logging
from typing import Any
import flask_login
from flask import make_response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Unauthorized
import services
from configs import dify_config
@ -43,18 +42,18 @@ from libs.token import (
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
)
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
logger = logging.getLogger(__name__)
class LoginPayload(LoginPayloadBase):
class LoginPayload(BaseModel):
email: EmailStr = Field(..., description="Email address")
password: str = Field(..., description="Password")
remember_me: bool = Field(default=False, description="Remember me flag")
invite_token: str | None = Field(default=None, description="Invitation token")
@ -95,16 +94,14 @@ class LoginApi(Resource):
normalized_email = request_email.lower()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
raise AccountInFreezeError()
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
if is_login_error_rate_limit:
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.LOGIN_RATE_LIMITED)
raise EmailPasswordLoginLimitError()
invite_token = args.invite_token
invitation_data: InvitationDetailDict | None = None
invitation_data: dict[str, Any] | None = None
if invite_token:
invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
if invitation_data is None:
@ -116,20 +113,14 @@ class LoginApi(Resource):
invitee_email = data.get("email") if data else None
invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
if invitee_email_normalized != normalized_email:
_log_console_login_failure(
email=normalized_email,
reason=LoginFailureReason.INVALID_INVITATION_EMAIL,
)
raise InvalidEmailError()
account = _authenticate_account_with_case_fallback(
request_email, normalized_email, args.password, invite_token
)
except services.errors.account.AccountLoginError:
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError()
except services.errors.account.AccountPasswordError as exc:
AccountService.add_login_error_rate_limit(normalized_email)
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
raise AuthenticationFailedError() from exc
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
@ -252,27 +243,20 @@ class EmailCodeLoginApi(Resource):
token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
raise InvalidTokenError()
token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if normalized_token_email != user_email:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
raise InvalidEmailError()
if token_data["code"] != args.code:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
raise EmailCodeError()
AccountService.revoke_email_code_login_token(args.token)
try:
account = _get_account_with_case_fallback(original_email)
except Unauthorized as exc:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError() from exc
except AccountRegisterError:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
raise AccountInFreezeError()
if account:
tenants = TenantService.get_join_tenants(account)
@ -298,7 +282,6 @@ class EmailCodeLoginApi(Resource):
except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace()
except AccountRegisterError:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
raise AccountInFreezeError()
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
@ -356,12 +339,3 @@ def _authenticate_account_with_case_fallback(
if original_email == normalized_email:
raise
return AccountService.authenticate(normalized_email, password, invite_token)
def _log_console_login_failure(*, email: str, reason: LoginFailureReason) -> None:
logger.warning(
"Console login failed: email=%s reason=%s ip_address=%s",
email,
reason,
extract_remote_ip(request),
)

View File

@ -4,6 +4,7 @@ import urllib.parse
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Unauthorized
from configs import dify_config
@ -179,7 +180,8 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
account: Account | None = Account.get_by_openid(provider, user_info.id)
if not account:
account = AccountService.get_account_by_email_with_case_fallback(user_info.email)
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
return account

View File

@ -1,9 +1,8 @@
from collections.abc import Callable
from functools import wraps
from typing import Concatenate
from typing import Concatenate, ParamSpec, TypeVar
from flask import jsonify, request
from flask.typing import ResponseReturnValue
from flask_restx import Resource
from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel
@ -17,6 +16,10 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,
from .. import console_ns
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
class OAuthClientPayload(BaseModel):
client_id: str
@ -36,11 +39,9 @@ class OAuthTokenRequest(BaseModel):
refresh_token: str | None = None
def oauth_server_client_id_required[T, **P, R](
view: Callable[Concatenate[T, OAuthProviderApp, P], R],
) -> Callable[Concatenate[T, P], R]:
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
json_data = request.get_json()
if json_data is None:
raise BadRequest("client_id is required")
@ -57,13 +58,9 @@ def oauth_server_client_id_required[T, **P, R](
return decorated
def oauth_server_access_token_required[T, **P, R](
view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R],
) -> Callable[Concatenate[T, OAuthProviderApp, P], R | ResponseReturnValue]:
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
@wraps(view)
def decorated(
self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs
) -> R | ResponseReturnValue:
def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
if not isinstance(oauth_provider_app, OAuthProviderApp):
raise BadRequest("Invalid oauth_provider_app")

View File

@ -2,17 +2,18 @@ import base64
from typing import Literal
from flask import request
from flask_restx import Resource
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from werkzeug.exceptions import BadRequest
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from enums.cloud_plan import CloudPlan
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class SubscriptionQuery(BaseModel):
plan: Literal[CloudPlan.PROFESSIONAL, CloudPlan.TEAM] = Field(..., description="Subscription plan")
@ -23,7 +24,8 @@ class PartnerTenantsPayload(BaseModel):
click_id: str = Field(..., description="Click Id from partner referral link")
register_schema_models(console_ns, SubscriptionQuery, PartnerTenantsPayload)
for model in (SubscriptionQuery, PartnerTenantsPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/billing/subscription")
@ -56,7 +58,12 @@ class PartnerTenants(Resource):
@console_ns.doc("sync_partner_tenants_bindings")
@console_ns.doc(description="Sync partner tenants bindings")
@console_ns.doc(params={"partner_key": "Partner key"})
@console_ns.expect(console_ns.models[PartnerTenantsPayload.__name__])
@console_ns.expect(
console_ns.model(
"SyncPartnerTenantsBindingsRequest",
{"click_id": fields.String(required=True, description="Click Id from partner referral link")},
)
)
@console_ns.response(200, "Tenants synced to partner successfully")
@console_ns.response(400, "Invalid partner information")
@setup_required

View File

@ -158,13 +158,10 @@ class DataSourceApi(Resource):
@login_required
@account_initialization_required
def patch(self, binding_id, action: Literal["enable", "disable"]):
_, current_tenant_id = current_account_with_tenant()
binding_id = str(binding_id)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.id == binding_id, DataSourceOauthBinding.tenant_id == current_tenant_id
)
select(DataSourceOauthBinding).filter_by(id=binding_id)
).scalar_one_or_none()
if data_source_binding is None:
raise NotFound("Data source binding not found.")
@ -224,11 +221,11 @@ class DataSourceNotionListApi(Resource):
raise ValueError("Dataset is not notion type.")
documents = session.scalars(
select(Document).where(
Document.dataset_id == query.dataset_id,
Document.tenant_id == current_tenant_id,
Document.data_source_type == "notion_import",
Document.enabled.is_(True),
select(Document).filter_by(
dataset_id=query.dataset_id,
tenant_id=current_tenant_id,
data_source_type="notion_import",
enabled=True,
)
).all()
if documents:

View File

@ -11,7 +11,10 @@ import services
from configs import dify_config
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.apikey import ApiKeyItem, ApiKeyList
from controllers.console.apikey import (
api_key_item_model,
api_key_list_model,
)
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from controllers.console.wraps import (
@ -782,23 +785,23 @@ class DatasetApiKeyApi(Resource):
@console_ns.doc("get_dataset_api_keys")
@console_ns.doc(description="Get dataset API keys")
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
@console_ns.response(200, "API keys retrieved successfully", api_key_list_model)
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_key_list_model)
def get(self):
_, current_tenant_id = current_account_with_tenant()
keys = db.session.scalars(
select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
).all()
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
return {"items": keys}
@console_ns.response(200, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
@marshal_with(api_key_item_model)
def post(self):
_, current_tenant_id = current_account_with_tenant()
@ -825,7 +828,7 @@ class DatasetApiKeyApi(Resource):
api_token.type = self.resource_type
db.session.add(api_token)
db.session.commit()
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 200
return api_token, 200
@console_ns.route("/datasets/api-keys/<uuid:api_key_id>")

View File

@ -4,6 +4,7 @@ from argparse import ArgumentTypeError
from collections.abc import Sequence
from contextlib import ExitStack
from typing import Any, Literal, cast
from uuid import UUID
import sqlalchemy as sa
from flask import request, send_file
@ -15,7 +16,6 @@ from sqlalchemy import asc, desc, func, select
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from core.errors.error import (
@ -71,6 +71,9 @@ from ..wraps import (
logger = logging.getLogger(__name__)
# NOTE: Keep constants near the top of the module for discoverability.
DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
# Register models for flask_restx to avoid dict type issues in Swagger
dataset_model = get_or_create_model("Dataset", dataset_fields)
@ -107,6 +110,12 @@ class GenerateSummaryPayload(BaseModel):
document_list: list[str]
class DocumentBatchDownloadZipPayload(BaseModel):
"""Request payload for bulk downloading documents as a zip archive."""
document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
class DocumentDatasetListParam(BaseModel):
page: int = Field(1, title="Page", description="Page number.")
limit: int = Field(20, title="Limit", description="Page size.")
@ -271,7 +280,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
query = select(Document).where(Document.dataset_id == str(dataset_id), Document.tenant_id == current_tenant_id)
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)
if status:
query = DocumentService.apply_display_status_filter(query, status)
@ -1026,7 +1035,7 @@ class DocumentMetadataApi(DocumentResource):
if not isinstance(doc_metadata, dict):
raise ValueError("doc_metadata must be a dictionary.")
metadata_schema: dict[str, Any] = cast(dict[str, Any], DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
metadata_schema: dict = cast(dict, DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
document.doc_metadata = {}
if doc_type == "others":

View File

@ -10,7 +10,6 @@ from werkzeug.exceptions import Forbidden, NotFound
import services
from configs import dify_config
from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError
@ -83,6 +82,14 @@ class BatchImportPayload(BaseModel):
upload_file_id: str
class ChildChunkCreatePayload(BaseModel):
content: str
class ChildChunkUpdatePayload(BaseModel):
content: str
class ChildChunkBatchUpdatePayload(BaseModel):
chunks: list[ChildChunkUpdateArgs]

View File

@ -173,11 +173,8 @@ class ExternalApiTemplateApi(Resource):
@login_required
@account_initialization_required
def get(self, external_knowledge_api_id):
_, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id)
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(
external_knowledge_api_id, current_tenant_id
)
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id)
if external_knowledge_api is None:
raise NotFound("API template not found.")
@ -227,11 +224,10 @@ class ExternalApiUseCheckApi(Resource):
@login_required
@account_initialization_required
def get(self, external_knowledge_api_id):
_, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id)
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
external_knowledge_api_id, current_tenant_id
external_knowledge_api_id
)
return {"is_using": external_knowledge_api_is_using, "count": count}, 200

View File

@ -1,13 +1,13 @@
from __future__ import annotations
from flask_restx import Resource, fields
from datetime import datetime
from typing import Any
from flask_restx import Resource
from pydantic import Field, field_validator
from controllers.common.schema import register_schema_models
from fields.base import ResponseModel
from controllers.common.schema import register_schema_model
from fields.hit_testing_fields import (
child_chunk_fields,
document_fields,
files_fields,
hit_testing_record_fields,
segment_fields,
)
from libs.login import login_required
from .. import console_ns
@ -18,92 +18,39 @@ from ..wraps import (
setup_required,
)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
register_schema_model(console_ns, HitTestingPayload)
class HitTestingDocument(ResponseModel):
id: str | None = None
data_source_type: str | None = None
name: str | None = None
doc_type: str | None = None
doc_metadata: Any | None = None
def _get_or_create_model(model_name: str, field_def):
"""Get or create a flask_restx model to avoid dict type issues in Swagger."""
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
class HitTestingSegment(ResponseModel):
id: str | None = None
position: int | None = None
document_id: str | None = None
content: str | None = None
sign_content: str | None = None
answer: str | None = None
word_count: int | None = None
tokens: int | None = None
keywords: list[str] = Field(default_factory=list)
index_node_id: str | None = None
index_node_hash: str | None = None
hit_count: int | None = None
enabled: bool | None = None
disabled_at: int | None = None
disabled_by: str | None = None
status: str | None = None
created_by: str | None = None
created_at: int | None = None
indexing_at: int | None = None
completed_at: int | None = None
error: str | None = None
stopped_at: int | None = None
document: HitTestingDocument | None = None
# Register models for flask_restx to avoid dict type issues in Swagger
document_model = _get_or_create_model("HitTestingDocument", document_fields)
@field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
segment_fields_copy = segment_fields.copy()
segment_fields_copy["document"] = fields.Nested(document_model)
segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
files_model = _get_or_create_model("HitTestingFile", files_fields)
class HitTestingChildChunk(ResponseModel):
id: str | None = None
content: str | None = None
position: int | None = None
score: float | None = None
hit_testing_record_fields_copy = hit_testing_record_fields.copy()
hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
class HitTestingFile(ResponseModel):
id: str | None = None
name: str | None = None
size: int | None = None
extension: str | None = None
mime_type: str | None = None
source_url: str | None = None
class HitTestingRecord(ResponseModel):
segment: HitTestingSegment | None = None
child_chunks: list[HitTestingChildChunk] = Field(default_factory=list)
score: float | None = None
tsne_position: Any | None = None
files: list[HitTestingFile] = Field(default_factory=list)
summary: str | None = None
class HitTestingResponse(ResponseModel):
query: str
records: list[HitTestingRecord] = Field(default_factory=list)
register_schema_models(
console_ns,
HitTestingPayload,
HitTestingDocument,
HitTestingSegment,
HitTestingChildChunk,
HitTestingFile,
HitTestingRecord,
HitTestingResponse,
)
# Response model for hit testing API
hit_testing_response_fields = {
"query": fields.String,
"records": fields.List(fields.Nested(hit_testing_record_model)),
}
hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
@ -112,11 +59,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
@console_ns.doc(description="Test dataset knowledge retrieval")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
@console_ns.response(
200,
"Hit testing completed successfully",
model=console_ns.models[HitTestingResponse.__name__],
)
@console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
@console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters")
@setup_required
@ -131,4 +74,4 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
args = payload.model_dump(exclude_none=True)
self.hit_testing_args_check(args)
return HitTestingResponse.model_validate(self.perform_hit_testing(dataset, args)).model_dump(mode="json")
return self.perform_hit_testing(dataset, args)

View File

@ -1,9 +1,9 @@
from typing import Literal
from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import MetadataUpdatePayload
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
@ -18,6 +18,11 @@ from services.entities.knowledge_entities.knowledge_entities import (
)
from services.metadata_service import MetadataService
class MetadataUpdatePayload(BaseModel):
name: str
register_schema_models(
console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail
)

View File

@ -3,7 +3,6 @@ import logging
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
@ -87,8 +86,8 @@ class CustomizedPipelineTemplateApi(Resource):
@enterprise_license_required
def post(self, template_id: str):
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
template = session.scalar(
select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).limit(1)
template = (
session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
)
if not template:
raise ValueError("Customized pipeline template not found.")

View File

@ -1,5 +1,4 @@
import logging
from collections.abc import Callable
from typing import Any, NoReturn
from flask import Response, request
@ -56,7 +55,7 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
def _api_prerequisite(f):
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
@ -71,7 +70,7 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
@login_required
@account_initialization_required
@get_rag_pipeline
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
def wrapper(*args, **kwargs):
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
return f(*args, **kwargs)
@ -223,27 +222,24 @@ class RagPipelineVariableApi(Resource):
new_value = None
if raw_value is not None:
match variable.value_type:
case SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
case SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
case _:
pass
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()

View File

@ -19,7 +19,7 @@ from fields.rag_pipeline_fields import (
)
from libs.login import current_account_with_tenant, login_required
from models.dataset import Pipeline
from services.entities.dsl_entities import ImportStatus
from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
@ -83,13 +83,11 @@ class RagPipelineImportApi(Resource):
# Return appropriate status code based on result
status = result.status
match status:
case ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
case ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS:
return result.model_dump(mode="json"), 200
if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
@console_ns.route("/rag/pipelines/imports/<string:import_id>/confirm")

View File

@ -10,7 +10,6 @@ from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
import services
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
@ -95,6 +94,22 @@ class PublishedWorkflowRunPayload(DraftWorkflowRunPayload):
original_document_id: str | None = None
class DefaultBlockConfigQuery(BaseModel):
q: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class NodeIdQuery(BaseModel):
node_id: str
@ -346,6 +361,89 @@ class PublishedRagPipelineRunApi(Resource):
raise InvokeRateLimitHttpError(ex.description)
# class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource):
# @setup_required
# @login_required
# @account_initialization_required
# @get_rag_pipeline
# def post(self, pipeline: Pipeline, node_id: str):
# """
# Run rag pipeline datasource
# """
# # The role of the current user in the ta table must be admin, owner, or editor
# if not current_user.has_edit_permission:
# raise Forbidden()
#
# if not isinstance(current_user, Account):
# raise Forbidden()
#
# parser = (reqparse.RequestParser()
# .add_argument("job_id", type=str, required=True, nullable=False, location="json")
# .add_argument("datasource_type", type=str, required=True, location="json")
# )
# args = parser.parse_args()
#
# job_id = args.get("job_id")
# if job_id == None:
# raise ValueError("missing job_id")
# datasource_type = args.get("datasource_type")
# if datasource_type == None:
# raise ValueError("missing datasource_type")
#
# rag_pipeline_service = RagPipelineService()
# result = rag_pipeline_service.run_datasource_workflow_node_status(
# pipeline=pipeline,
# node_id=node_id,
# job_id=job_id,
# account=current_user,
# datasource_type=datasource_type,
# is_published=True
# )
#
# return result
# class RagPipelineDraftDatasourceNodeRunStatusApi(Resource):
# @setup_required
# @login_required
# @account_initialization_required
# @get_rag_pipeline
# def post(self, pipeline: Pipeline, node_id: str):
# """
# Run rag pipeline datasource
# """
# # The role of the current user in the ta table must be admin, owner, or editor
# if not current_user.has_edit_permission:
# raise Forbidden()
#
# if not isinstance(current_user, Account):
# raise Forbidden()
#
# parser = (reqparse.RequestParser()
# .add_argument("job_id", type=str, required=True, nullable=False, location="json")
# .add_argument("datasource_type", type=str, required=True, location="json")
# )
# args = parser.parse_args()
#
# job_id = args.get("job_id")
# if job_id == None:
# raise ValueError("missing job_id")
# datasource_type = args.get("datasource_type")
# if datasource_type == None:
# raise ValueError("missing datasource_type")
#
# rag_pipeline_service = RagPipelineService()
# result = rag_pipeline_service.run_datasource_workflow_node_status(
# pipeline=pipeline,
# node_id=node_id,
# job_id=job_id,
# account=current_user,
# datasource_type=datasource_type,
# is_published=False
# )
#
# return result
#
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])

View File

@ -1,5 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from sqlalchemy import select
@ -8,10 +9,13 @@ from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models.dataset import Pipeline
P = ParamSpec("P")
R = TypeVar("R")
def get_rag_pipeline[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
def get_rag_pipeline(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated_view(*args: P.args, **kwargs: P.kwargs):
if not kwargs.get("pipeline_id"):
raise ValueError("missing pipeline_id in path parameters")

View File

@ -2,10 +2,10 @@ import logging
from flask import request
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
import services
from controllers.common.controller_schemas import TextToAudioPayload
from controllers.common.schema import register_schema_model
from controllers.console.app.error import (
AppUnavailableError,
@ -32,6 +32,14 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class TextToAudioPayload(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = Field(default=None, description="Enable streaming response")
register_schema_model(console_ns, TextToAudioPayload)

View File

@ -1,11 +1,10 @@
from typing import Any
from flask import request
from pydantic import BaseModel, Field, TypeAdapter
from pydantic import BaseModel, Field, TypeAdapter, model_validator
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import ConversationRenamePayload
from controllers.common.schema import register_schema_models
from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
@ -33,6 +32,18 @@ class ConversationListQuery(BaseModel):
pinned: bool | None = None
class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)

View File

@ -1,24 +1,21 @@
import logging
from datetime import datetime
from typing import Any
from flask import request
from flask_restx import Resource
from graphon.file import helpers as file_helpers
from pydantic import BaseModel, Field, computed_field, field_validator
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import and_, select
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.common.schema import register_schema_models
from controllers.common.schema import get_or_create_model
from controllers.console import console_ns
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.installed_app_fields import app_fields, installed_app_fields, installed_app_list_fields
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import App, InstalledApp, RecommendedApp
from models.model import IconType
from services.account_service import TenantService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
@ -39,97 +36,22 @@ class InstalledAppsListQuery(BaseModel):
logger = logging.getLogger(__name__)
def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
if icon is None or icon_type is None:
return None
icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
if icon_type_value.lower() != IconType.IMAGE:
return None
return file_helpers.get_signed_file_url(icon)
app_model = get_or_create_model("InstalledAppInfo", app_fields)
installed_app_fields_copy = installed_app_fields.copy()
installed_app_fields_copy["app"] = fields.Nested(app_model)
installed_app_model = get_or_create_model("InstalledApp", installed_app_fields_copy)
def _safe_primitive(value: Any) -> Any:
if value is None or isinstance(value, (str, int, float, bool, datetime)):
return value
return None
class InstalledAppInfoResponse(ResponseModel):
id: str
name: str | None = None
mode: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
use_icon_as_answer_icon: bool | None = None
@field_validator("mode", "icon_type", mode="before")
@classmethod
def _normalize_enum_like(cls, value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(getattr(value, "value", value))
@computed_field(return_type=str | None) # type: ignore[prop-decorator]
@property
def icon_url(self) -> str | None:
return _build_icon_url(self.icon_type, self.icon)
class InstalledAppResponse(ResponseModel):
id: str
app: InstalledAppInfoResponse
app_owner_tenant_id: str
is_pinned: bool
last_used_at: int | None = None
editable: bool
uninstallable: bool
@field_validator("app", mode="before")
@classmethod
def _normalize_app(cls, value: Any) -> Any:
if isinstance(value, dict):
return value
return {
"id": _safe_primitive(getattr(value, "id", "")) or "",
"name": _safe_primitive(getattr(value, "name", None)),
"mode": _safe_primitive(getattr(value, "mode", None)),
"icon_type": _safe_primitive(getattr(value, "icon_type", None)),
"icon": _safe_primitive(getattr(value, "icon", None)),
"icon_background": _safe_primitive(getattr(value, "icon_background", None)),
"use_icon_as_answer_icon": _safe_primitive(getattr(value, "use_icon_as_answer_icon", None)),
}
@field_validator("last_used_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class InstalledAppListResponse(ResponseModel):
installed_apps: list[InstalledAppResponse]
register_schema_models(
console_ns,
InstalledAppCreatePayload,
InstalledAppUpdatePayload,
InstalledAppsListQuery,
InstalledAppInfoResponse,
InstalledAppResponse,
InstalledAppListResponse,
)
installed_app_list_fields_copy = installed_app_list_fields.copy()
installed_app_list_fields_copy["installed_apps"] = fields.List(fields.Nested(installed_app_model))
installed_app_list_model = get_or_create_model("InstalledAppList", installed_app_list_fields_copy)
@console_ns.route("/installed-apps")
class InstalledAppsListApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__])
@marshal_with(installed_app_list_model)
def get(self):
query = InstalledAppsListQuery.model_validate(request.args.to_dict())
current_user, current_tenant_id = current_account_with_tenant()
@ -203,9 +125,7 @@ class InstalledAppsListApi(Resource):
)
)
return InstalledAppListResponse.model_validate(
{"installed_apps": installed_app_list}, from_attributes=True
).model_dump(mode="json")
return {"installed_apps": installed_app_list}
@login_required
@account_initialization_required

View File

@ -3,10 +3,9 @@ from typing import Literal
from flask import request
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, TypeAdapter
from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
from controllers.common.schema import register_schema_models
from controllers.console.app.error import (
AppMoreLikeThisDisabledError,
@ -26,6 +25,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
from libs import helper
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from models.enums import FeedbackRating
from models.model import AppMode
@ -44,6 +44,17 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = None
content: str | None = None
class MoreLikeThisQuery(BaseModel):
response_mode: Literal["blocking", "streaming"]

View File

@ -1,83 +1,66 @@
from typing import Any
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, computed_field, field_validator
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from constants.languages import languages
from controllers.common.schema import register_schema_models
from controllers.common.schema import get_or_create_model
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required
from fields.base import ResponseModel
from libs.helper import build_icon_url
from libs.helper import AppIconUrlField
from libs.login import current_user, login_required
from services.recommended_app_service import RecommendedAppService
app_fields = {
"id": fields.String,
"name": fields.String,
"mode": fields.String,
"icon": fields.String,
"icon_type": fields.String,
"icon_url": AppIconUrlField,
"icon_background": fields.String,
}
app_model = get_or_create_model("RecommendedAppInfo", app_fields)
recommended_app_fields = {
"app": fields.Nested(app_model, attribute="app"),
"app_id": fields.String,
"description": fields.String(attribute="description"),
"copyright": fields.String,
"privacy_policy": fields.String,
"custom_disclaimer": fields.String,
"category": fields.String,
"position": fields.Integer,
"is_listed": fields.Boolean,
"can_trial": fields.Boolean,
}
recommended_app_model = get_or_create_model("RecommendedApp", recommended_app_fields)
recommended_app_list_fields = {
"recommended_apps": fields.List(fields.Nested(recommended_app_model)),
"categories": fields.List(fields.String),
}
recommended_app_list_model = get_or_create_model("RecommendedAppList", recommended_app_list_fields)
class RecommendedAppsQuery(BaseModel):
language: str | None = Field(default=None)
class RecommendedAppInfoResponse(ResponseModel):
id: str
name: str | None = None
mode: str | None = None
icon: str | None = None
icon_type: str | None = None
icon_background: str | None = None
@staticmethod
def _normalize_enum_like(value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(getattr(value, "value", value))
@field_validator("mode", "icon_type", mode="before")
@classmethod
def _normalize_enum_fields(cls, value: Any) -> str | None:
return cls._normalize_enum_like(value)
@computed_field(return_type=str | None) # type: ignore[prop-decorator]
@property
def icon_url(self) -> str | None:
return build_icon_url(self.icon_type, self.icon)
class RecommendedAppResponse(ResponseModel):
app: RecommendedAppInfoResponse | None = None
app_id: str
description: str | None = None
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
category: str | None = None
position: int | None = None
is_listed: bool | None = None
can_trial: bool | None = None
class RecommendedAppListResponse(ResponseModel):
recommended_apps: list[RecommendedAppResponse]
categories: list[str]
register_schema_models(
console_ns,
RecommendedAppsQuery,
RecommendedAppInfoResponse,
RecommendedAppResponse,
RecommendedAppListResponse,
console_ns.schema_model(
RecommendedAppsQuery.__name__,
RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"),
)
@console_ns.route("/explore/apps")
class RecommendedAppListApi(Resource):
@console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
@console_ns.response(200, "Success", console_ns.models[RecommendedAppListResponse.__name__])
@login_required
@account_initialization_required
@marshal_with(recommended_app_list_model)
def get(self):
# language args
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
@ -89,10 +72,7 @@ class RecommendedAppListApi(Resource):
else:
language_prefix = languages[0]
return RecommendedAppListResponse.model_validate(
RecommendedAppService.get_recommended_apps_and_categories(language_prefix),
from_attributes=True,
).model_dump(mode="json")
return RecommendedAppService.get_recommended_apps_and_categories(language_prefix)
@console_ns.route("/explore/apps/<uuid:app_id>")

View File

@ -1,18 +1,28 @@
from flask import request
from pydantic import TypeAdapter
from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUIDStrOrEmpty
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)

View File

@ -169,7 +169,6 @@ console_ns.schema_model(
class TrialAppWorkflowRunApi(TrialAppResource):
@trial_feature_enable
@console_ns.expect(console_ns.models[WorkflowRunRequest.__name__])
def post(self, trial_app):
"""
@ -211,7 +210,6 @@ class TrialAppWorkflowRunApi(TrialAppResource):
class TrialAppWorkflowTaskStopApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app, task_id: str):
"""
Stop workflow task
@ -292,6 +290,7 @@ class TrialChatApi(TrialAppResource):
class TrialMessageSuggestedQuestionApi(TrialAppResource):
@trial_feature_enable
def get(self, trial_app, message_id):
app_model = trial_app
app_mode = AppMode.value_of(app_model.mode)
@ -471,6 +470,7 @@ class TrialCompletionApi(TrialAppResource):
class TrialSitApi(Resource):
"""Resource for trial app sites."""
@trial_feature_enable
@get_app_model_with_trial(None)
def get(self, app_model):
"""Retrieve app site info.
@ -492,6 +492,7 @@ class TrialSitApi(Resource):
class TrialAppParameterApi(Resource):
"""Resource for app variables."""
@trial_feature_enable
@get_app_model_with_trial(None)
def get(self, app_model):
"""Retrieve app parameters."""
@ -520,6 +521,7 @@ class TrialAppParameterApi(Resource):
class AppApi(Resource):
@trial_feature_enable
@get_app_model_with_trial(None)
@marshal_with(app_detail_with_site_model)
def get(self, app_model):
@ -532,6 +534,7 @@ class AppApi(Resource):
class AppWorkflowApi(Resource):
@trial_feature_enable
@get_app_model_with_trial(None)
@marshal_with(workflow_model)
def get(self, app_model):
@ -544,6 +547,7 @@ class AppWorkflowApi(Resource):
class DatasetListApi(Resource):
@trial_feature_enable
@get_app_model_with_trial(None)
def get(self, app_model):
page = request.args.get("page", default=1, type=int)

View File

@ -1,10 +1,11 @@
import logging
from typing import Any
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel
from werkzeug.exceptions import InternalServerError
from controllers.common.controller_schemas import WorkflowRunPayload
from controllers.common.schema import register_schema_model
from controllers.console.app.error import (
CompletionRequestError,
@ -33,6 +34,12 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
register_schema_model(console_ns, WorkflowRunPayload)

View File

@ -1,6 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import Concatenate
from typing import Concatenate, ParamSpec, TypeVar
from flask import abort
from flask_restx import Resource
@ -15,8 +15,12 @@ from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def installed_app_required[**P, R](view: Callable[Concatenate[InstalledApp, P], R] | None = None):
def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | None = None):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
@ -45,7 +49,7 @@ def installed_app_required[**P, R](view: Callable[Concatenate[InstalledApp, P],
return decorator
def user_allowed_to_access_app[**P, R](view: Callable[Concatenate[InstalledApp, P], R] | None = None):
def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] | None = None):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
@ -69,7 +73,7 @@ def user_allowed_to_access_app[**P, R](view: Callable[Concatenate[InstalledApp,
return decorator
def trial_app_required[**P, R](view: Callable[Concatenate[App, P], R] | None = None):
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
def decorator(view: Callable[Concatenate[App, P], R]):
@wraps(view)
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
@ -102,7 +106,7 @@ def trial_app_required[**P, R](view: Callable[Concatenate[App, P], R] | None = N
return decorator
def trial_feature_enable[**P, R](view: Callable[P, R]):
def trial_feature_enable(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()
@ -113,7 +117,7 @@ def trial_feature_enable[**P, R](view: Callable[P, R]):
return decorated
def explore_banner_enabled[**P, R](view: Callable[P, R]):
def explore_banner_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()

View File

@ -1,18 +1,15 @@
from datetime import datetime
from typing import Any
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from constants import HIDDEN_VALUE
from fields.base import ResponseModel
from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import current_account_with_tenant, login_required
from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService
from ..common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_models
from ..common.schema import register_schema_models
from . import console_ns
from .wraps import account_initialization_required, setup_required
@ -27,52 +24,12 @@ class APIBasedExtensionPayload(BaseModel):
api_key: str = Field(description="API key for authentication")
class CodeBasedExtensionResponse(ResponseModel):
module: str = Field(description="Module name")
data: Any = Field(description="Extension data")
register_schema_models(console_ns, APIBasedExtensionPayload)
def _mask_api_key(api_key: str) -> str:
if not api_key:
return api_key
if len(api_key) <= 8:
return api_key[0] + "******" + api_key[-1]
return api_key[:3] + "******" + api_key[-3:]
api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class APIBasedExtensionResponse(ResponseModel):
id: str
name: str
api_endpoint: str
api_key: str
created_at: int | None = None
@field_validator("api_key", mode="before")
@classmethod
def _normalize_api_key(cls, value: str) -> str:
return _mask_api_key(value)
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
register_schema_models(console_ns, APIBasedExtensionPayload, CodeBasedExtensionResponse, APIBasedExtensionResponse)
console_ns.schema_model(
"APIBasedExtensionListResponse",
TypeAdapter(list[APIBasedExtensionResponse]).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def _serialize_api_based_extension(extension: APIBasedExtension) -> dict[str, Any]:
return APIBasedExtensionResponse.model_validate(extension, from_attributes=True).model_dump(mode="json")
api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model))
@console_ns.route("/code-based-extension")
@ -83,7 +40,10 @@ class CodeBasedExtensionAPI(Resource):
@console_ns.response(
200,
"Success",
console_ns.models[CodeBasedExtensionResponse.__name__],
console_ns.model(
"CodeBasedExtensionResponse",
{"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")},
),
)
@setup_required
@login_required
@ -91,34 +51,30 @@ class CodeBasedExtensionAPI(Resource):
def get(self):
query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return CodeBasedExtensionResponse(
module=query.module,
data=CodeBasedExtensionService.get_code_based_extension(query.module),
).model_dump(mode="json")
return {"module": query.module, "data": CodeBasedExtensionService.get_code_based_extension(query.module)}
@console_ns.route("/api-based-extension")
class APIBasedExtensionAPI(Resource):
@console_ns.doc("get_api_based_extensions")
@console_ns.doc(description="Get all API-based extensions for current tenant")
@console_ns.response(200, "Success", console_ns.models["APIBasedExtensionListResponse"])
@console_ns.response(200, "Success", api_based_extension_list_model)
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def get(self):
_, tenant_id = current_account_with_tenant()
return [
_serialize_api_based_extension(extension)
for extension in APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
]
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
@console_ns.doc("create_api_based_extension")
@console_ns.doc(description="Create a new API-based extension")
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
@console_ns.response(201, "Extension created successfully", console_ns.models[APIBasedExtensionResponse.__name__])
@console_ns.response(201, "Extension created successfully", api_based_extension_model)
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def post(self):
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
_, current_tenant_id = current_account_with_tenant()
@ -130,7 +86,7 @@ class APIBasedExtensionAPI(Resource):
api_key=payload.api_key,
)
return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data))
return APIBasedExtensionService.save(extension_data)
@console_ns.route("/api-based-extension/<uuid:id>")
@ -138,26 +94,26 @@ class APIBasedExtensionDetailAPI(Resource):
@console_ns.doc("get_api_based_extension")
@console_ns.doc(description="Get API-based extension by ID")
@console_ns.doc(params={"id": "Extension ID"})
@console_ns.response(200, "Success", console_ns.models[APIBasedExtensionResponse.__name__])
@console_ns.response(200, "Success", api_based_extension_model)
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def get(self, id):
api_based_extension_id = str(id)
_, tenant_id = current_account_with_tenant()
return _serialize_api_based_extension(
APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
)
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
@console_ns.doc("update_api_based_extension")
@console_ns.doc(description="Update API-based extension")
@console_ns.doc(params={"id": "Extension ID"})
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
@console_ns.response(200, "Extension updated successfully", console_ns.models[APIBasedExtensionResponse.__name__])
@console_ns.response(200, "Extension updated successfully", api_based_extension_model)
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def post(self, id):
api_based_extension_id = str(id)
_, current_tenant_id = current_account_with_tenant()
@ -172,7 +128,7 @@ class APIBasedExtensionDetailAPI(Resource):
if payload.api_key != HIDDEN_VALUE:
extension_data_from_db.api_key = payload.api_key
return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data_from_db))
return APIBasedExtensionService.save(extension_data_from_db)
@console_ns.doc("delete_api_based_extension")
@console_ns.doc(description="Delete API-based extension")

View File

@ -7,8 +7,7 @@ import logging
from collections.abc import Generator
from flask import Response, jsonify, request
from flask_restx import Resource
from pydantic import BaseModel
from flask_restx import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
@ -34,11 +33,6 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
logger = logging.getLogger(__name__)
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict
action: str
def _jsonify_form_definition(form: Form) -> Response:
payload = form.get_definition().model_dump()
payload["expiration_time"] = int(form.expiration_time.timestamp())
@ -90,7 +84,10 @@ class ConsoleHumanInputFormApi(Resource):
"action": "Approve"
}
"""
payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("action", type=str, required=True, location="json")
args = parser.parse_args()
current_user, _ = current_account_with_tenant()
service = HumanInputService(db.engine)
@ -110,8 +107,8 @@ class ConsoleHumanInputFormApi(Resource):
service.submit_form_by_token(
recipient_type=recipient_type,
form_token=form_token,
selected_action_id=payload.action,
form_data=payload.inputs,
selected_action_id=args["action"],
form_data=args["inputs"],
submission_user_id=current_user.id,
)
@ -171,13 +168,12 @@ class ConsoleWorkflowEventsApi(Resource):
else:
msg_generator = MessageGenerator()
generator: BaseAppGenerator
match app.mode:
case AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
case AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
case _:
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
if app.mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
elif app.mode == AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
else:
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"

View File

@ -1,6 +1,3 @@
from collections.abc import Mapping
from typing import TypedDict
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
@ -14,34 +11,9 @@ from services.billing_service import BillingService
_FALLBACK_LANG = "en-US"
class NotificationLangContent(TypedDict, total=False):
lang: str
title: str
subtitle: str
body: str
titlePicUrl: str
class NotificationItemDict(TypedDict):
notification_id: str | None
frequency: str | None
lang: str
title: str
subtitle: str
body: str
title_pic_url: str
class NotificationResponseDict(TypedDict):
should_show: bool
notifications: list[NotificationItemDict]
def _pick_lang_content(contents: Mapping[str, NotificationLangContent], lang: str) -> NotificationLangContent:
def _pick_lang_content(contents: dict, lang: str) -> dict:
"""Return the single LangContent for *lang*, falling back to English."""
return (
contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), NotificationLangContent())
)
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
class DismissNotificationPayload(BaseModel):
@ -73,30 +45,28 @@ class NotificationApi(Resource):
result = BillingService.get_account_notification(str(current_user.id))
# Proto JSON uses camelCase field names (Kratos default marshaling).
response: NotificationResponseDict
if not result.get("shouldShow"):
response = {"should_show": False, "notifications": []}
return response, 200
return {"should_show": False, "notifications": []}, 200
lang = current_user.interface_language or _FALLBACK_LANG
notifications: list[NotificationItemDict] = []
notifications = []
for notification in result.get("notifications") or []:
contents: Mapping[str, NotificationLangContent] = notification.get("contents") or {}
contents: dict = notification.get("contents") or {}
lang_content = _pick_lang_content(contents, lang)
item: NotificationItemDict = {
"notification_id": notification.get("notificationId"),
"frequency": notification.get("frequency"),
"lang": lang_content.get("lang", lang),
"title": lang_content.get("title", ""),
"subtitle": lang_content.get("subtitle", ""),
"body": lang_content.get("body", ""),
"title_pic_url": lang_content.get("titlePicUrl", ""),
}
notifications.append(item)
notifications.append(
{
"notification_id": notification.get("notificationId"),
"frequency": notification.get("frequency"),
"lang": lang_content.get("lang", lang),
"title": lang_content.get("title", ""),
"subtitle": lang_content.get("subtitle", ""),
"body": lang_content.get("body", ""),
"title_pic_url": lang_content.get("titlePicUrl", ""),
}
)
response = {"should_show": bool(notifications), "notifications": notifications}
return response, 200
return {"should_show": bool(notifications), "notifications": notifications}, 200
@console_ns.route("/notification/dismiss")

View File

@ -1,40 +1,43 @@
from typing import Literal
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from flask_restx import Namespace, Resource, fields, marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from models.enums import TagType
from services.tag_service import (
SaveTagPayload,
TagBindingCreatePayload,
TagBindingDeletePayload,
TagService,
UpdateTagPayload,
)
from services.tag_service import TagService
dataset_tag_fields = {
"id": fields.String,
"name": fields.String,
"type": fields.String,
"binding_count": fields.String,
}
def build_dataset_tag_fields(api_or_ns: Namespace):
return api_or_ns.model("DataSetTag", dataset_tag_fields)
class TagBasePayload(BaseModel):
name: str = Field(description="Tag name", min_length=1, max_length=50)
type: TagType = Field(description="Tag type")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
class TagBindingPayload(BaseModel):
tag_ids: list[str] = Field(description="Tag IDs to bind")
target_id: str = Field(description="Target ID to bind tags to")
type: TagType = Field(description="Tag type")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
class TagBindingRemovePayload(BaseModel):
tag_id: str = Field(description="Tag ID to remove")
target_id: str = Field(description="Target ID to unbind tag from")
type: TagType = Field(description="Tag type")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
class TagListQueryParam(BaseModel):
@ -42,36 +45,12 @@ class TagListQueryParam(BaseModel):
keyword: str | None = Field(None, description="Search keyword")
class TagResponse(ResponseModel):
id: str
name: str
type: str | None = None
binding_count: str | None = None
@field_validator("type", mode="before")
@classmethod
def normalize_type(cls, value: TagType | str | None) -> str | None:
if value is None:
return None
if isinstance(value, TagType):
return value.value
return value
@field_validator("binding_count", mode="before")
@classmethod
def normalize_binding_count(cls, value: int | str | None) -> str | None:
if value is None:
return None
return str(value)
register_schema_models(
console_ns,
TagBasePayload,
TagBindingPayload,
TagBindingRemovePayload,
TagListQueryParam,
TagResponse,
)
@ -83,18 +62,14 @@ class TagListApi(Resource):
@console_ns.doc(
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
)
@console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])})
@marshal_with(dataset_tag_fields)
def get(self):
_, current_tenant_id = current_account_with_tenant()
raw_args = request.args.to_dict()
param = TagListQueryParam.model_validate(raw_args)
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
serialized_tags = [
TagResponse.model_validate(tag, from_attributes=True).model_dump(mode="json") for tag in tags
]
return serialized_tags, 200
return tags, 200
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
@setup_required
@ -107,11 +82,9 @@ class TagListApi(Resource):
raise Forbidden()
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type))
tag = TagService.save_tags(payload.model_dump())
response = TagResponse.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
).model_dump(mode="json")
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
return response, 200
@ -130,13 +103,11 @@ class TagUpdateDeleteApi(Resource):
raise Forbidden()
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=payload.type), tag_id)
tag = TagService.update_tags(payload.model_dump(), tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)
response = TagResponse.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
).model_dump(mode="json")
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
return response, 200
@ -165,9 +136,7 @@ class TagBindingCreateApi(Resource):
raise Forbidden()
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type)
)
TagService.save_tag_binding(payload.model_dump())
return {"result": "success"}, 200
@ -185,8 +154,6 @@ class TagBindingDeleteApi(Resource):
raise Forbidden()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
)
TagService.delete_tag_binding(payload.model_dump())
return {"result": "success"}, 200

View File

@ -1,7 +1,7 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden
@ -9,25 +9,28 @@ from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models.account import TenantPluginPermission
P = ParamSpec("P")
R = TypeVar("R")
def plugin_permission_required(
install_required: bool = False,
debug_required: bool = False,
):
def interceptor[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated(*args: P.args, **kwargs: P.kwargs):
current_user, current_tenant_id = current_account_with_tenant()
user = current_user
tenant_id = current_tenant_id
with sessionmaker(db.engine).begin() as session:
permission = session.scalar(
select(TenantPluginPermission)
permission = (
session.query(TenantPluginPermission)
.where(
TenantPluginPermission.tenant_id == tenant_id,
)
.limit(1)
.first()
)
if not permission:
@ -35,24 +38,22 @@ def plugin_permission_required(
return view(*args, **kwargs)
if install_required:
match permission.install_permission:
case TenantPluginPermission.InstallPermission.NOBODY:
if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY:
raise Forbidden()
if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS:
if not user.is_admin_or_owner:
raise Forbidden()
case TenantPluginPermission.InstallPermission.ADMINS:
if not user.is_admin_or_owner:
raise Forbidden()
case TenantPluginPermission.InstallPermission.EVERYONE:
pass
if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE:
pass
if debug_required:
match permission.debug_permission:
case TenantPluginPermission.DebugPermission.NOBODY:
if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY:
raise Forbidden()
if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS:
if not user.is_admin_or_owner:
raise Forbidden()
case TenantPluginPermission.DebugPermission.ADMINS:
if not user.is_admin_or_owner:
raise Forbidden()
case TenantPluginPermission.DebugPermission.EVERYONE:
pass
if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE:
pass
return view(*args, **kwargs)

View File

@ -1,13 +1,14 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Literal
from typing import Literal
import pytz
from flask import request
from flask_restx import Resource
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from constants.languages import supported_language
@ -37,10 +38,9 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.member_fields import Account as AccountResponse
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, extract_remote_ip, timezone
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
from models import AccountIntegrate, InvitationCode
from models.account import AccountStatus, InvitationCodeStatus
@ -175,61 +175,21 @@ reg(CheckEmailUniquePayload)
register_schema_models(console_ns, AccountResponse)
def _serialize_account(account) -> dict[str, Any]:
def _serialize_account(account) -> dict:
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
integrate_fields = {
"provider": fields.String,
"created_at": TimestampField,
"is_bound": fields.Boolean,
"link": fields.String,
}
class AccountIntegrateResponse(ResponseModel):
provider: str
created_at: int | None = None
is_bound: bool
link: str | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class AccountIntegrateListResponse(ResponseModel):
data: list[AccountIntegrateResponse]
class EducationVerifyResponse(ResponseModel):
token: str | None = None
class EducationStatusResponse(ResponseModel):
result: bool | None = None
is_student: bool | None = None
expire_at: int | None = None
allow_refresh: bool | None = None
@field_validator("expire_at", mode="before")
@classmethod
def _normalize_expire_at(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class EducationAutocompleteResponse(ResponseModel):
data: list[str] = Field(default_factory=list)
curr_page: int | None = None
has_next: bool | None = None
register_schema_models(
console_ns,
AccountIntegrateResponse,
AccountIntegrateListResponse,
EducationVerifyResponse,
EducationStatusResponse,
EducationAutocompleteResponse,
integrate_model = console_ns.model("AccountIntegrate", integrate_fields)
integrate_list_model = console_ns.model(
"AccountIntegrateList",
{"data": fields.List(fields.Nested(integrate_model))},
)
@ -400,7 +360,7 @@ class AccountIntegrateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountIntegrateListResponse.__name__])
@marshal_with(integrate_list_model)
def get(self):
account, _ = current_account_with_tenant()
@ -436,9 +396,7 @@ class AccountIntegrateApi(Resource):
}
)
return AccountIntegrateListResponse(
data=[AccountIntegrateResponse.model_validate(item) for item in integrate_data]
).model_dump(mode="json")
return {"data": integrate_data}
@console_ns.route("/account/delete/verify")
@ -490,22 +448,31 @@ class AccountDeleteUpdateFeedbackApi(Resource):
@console_ns.route("/account/education/verify")
class EducationVerifyApi(Resource):
verify_fields = {
"token": fields.String,
}
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@console_ns.response(200, "Success", console_ns.models[EducationVerifyResponse.__name__])
@marshal_with(verify_fields)
def get(self):
account, _ = current_account_with_tenant()
return EducationVerifyResponse.model_validate(
BillingService.EducationIdentity.verify(account.id, account.email) or {}
).model_dump(mode="json")
return BillingService.EducationIdentity.verify(account.id, account.email)
@console_ns.route("/account/education")
class EducationApi(Resource):
status_fields = {
"result": fields.Boolean,
"is_student": fields.Boolean,
"expire_at": TimestampField,
"allow_refresh": fields.Boolean,
}
@console_ns.expect(console_ns.models[EducationActivatePayload.__name__])
@setup_required
@login_required
@ -525,33 +492,37 @@ class EducationApi(Resource):
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@console_ns.response(200, "Success", console_ns.models[EducationStatusResponse.__name__])
@marshal_with(status_fields)
def get(self):
account, _ = current_account_with_tenant()
res = BillingService.EducationIdentity.status(account.id) or {}
res = BillingService.EducationIdentity.status(account.id)
# convert expire_at to UTC timestamp from isoformat
if res and "expire_at" in res:
res["expire_at"] = datetime.fromisoformat(res["expire_at"]).astimezone(pytz.utc)
return EducationStatusResponse.model_validate(res).model_dump(mode="json")
return res
@console_ns.route("/account/education/autocomplete")
class EducationAutoCompleteApi(Resource):
data_fields = {
"data": fields.List(fields.String),
"curr_page": fields.Integer,
"has_next": fields.Boolean,
}
@console_ns.expect(console_ns.models[EducationAutocompleteQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@console_ns.response(200, "Success", console_ns.models[EducationAutocompleteResponse.__name__])
@marshal_with(data_fields)
def get(self):
payload = request.args.to_dict(flat=True)
args = EducationAutocompleteQuery.model_validate(payload)
return EducationAutocompleteResponse.model_validate(
BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit) or {}
).model_dump(mode="json")
return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit)
@console_ns.route("/account/change-email")
@ -591,7 +562,8 @@ class ChangeEmailSendEmailApi(Resource):
user_email = current_user.email
else:
account = AccountService.get_account_by_email_with_case_fallback(args.email)
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
if account is None:
raise AccountNotFound()
email_for_sending = account.email

View File

@ -465,7 +465,7 @@ class ModelProviderModelDisableApi(Resource):
class ParserValidate(BaseModel):
model: str
model_type: ModelType
credentials: dict[str, Any]
credentials: dict
console_ns.schema_model(

View File

@ -1,9 +1,8 @@
import logging
from datetime import datetime
from flask import request
from flask_restx import Resource, fields, marshal
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from werkzeug.exceptions import Unauthorized
@ -27,10 +26,9 @@ from controllers.console.wraps import (
)
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
from services.billing_service import BillingService, SubscriptionPlan
from services.enterprise.enterprise_service import EnterpriseService
@ -60,37 +58,6 @@ class WorkspaceInfoPayload(BaseModel):
name: str
class TenantInfoResponse(ResponseModel):
id: str
name: str | None = None
plan: str | None = None
status: str | None = None
created_at: int | None = None
role: str | None = None
in_trial: bool | None = None
trial_end_reason: str | None = None
custom_config: dict | None = None
trial_credits: int | None = None
trial_credits_used: int | None = None
next_credit_reset_date: int | None = None
@field_validator("plan", "status", "trial_end_reason", mode="before")
@classmethod
def _normalize_enum_like(cls, value):
if value is None:
return None
if isinstance(value, str):
return value
return str(getattr(value, "value", value))
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None):
if isinstance(value, datetime):
return int(value.timestamp())
return value
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@ -99,7 +66,6 @@ reg(WorkspaceListQuery)
reg(SwitchWorkspacePayload)
reg(WorkspaceCustomConfigPayload)
reg(WorkspaceInfoPayload)
reg(TenantInfoResponse)
provider_fields = {
"provider_name": fields.String,
@ -214,7 +180,7 @@ class TenantApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[TenantInfoResponse.__name__])
@marshal_with(tenant_fields)
def post(self):
if request.path == "/info":
logger.warning("Deprecated URL /info was used.")
@ -234,13 +200,7 @@ class TenantApi(Resource):
else:
raise Unauthorized("workspace is archived")
return (
TenantInfoResponse.model_validate(
WorkspaceService.get_tenant_info(tenant),
from_attributes=True,
).model_dump(mode="json"),
200,
)
return WorkspaceService.get_tenant_info(tenant), 200
@console_ns.route("/workspaces/switch")
@ -280,10 +240,8 @@ class CustomConfigWorkspaceApi(Resource):
args = WorkspaceCustomConfigPayload.model_validate(payload)
tenant = db.get_or_404(Tenant, current_tenant_id)
custom_config_dict: TenantCustomConfigDict = {
"remove_webapp_brand": args.remove_webapp_brand
if args.remove_webapp_brand is not None
else tenant.custom_config_dict.get("remove_webapp_brand", False),
custom_config_dict = {
"remove_webapp_brand": args.remove_webapp_brand,
"replace_webapp_logo": args.replace_webapp_logo
if args.replace_webapp_logo is not None
else tenant.custom_config_dict.get("replace_webapp_logo"),

Some files were not shown because too many files have changed in this diff Show More