mirror of
https://github.com/langgenius/dify.git
synced 2026-01-20 03:59:30 +08:00
Compare commits
146 Commits
build/bool
...
pinecone
| Author | SHA1 | Date | |
|---|---|---|---|
| 594906c1ff | |||
| 80f8245f2e | |||
| a12b437c16 | |||
| 12de554313 | |||
| 1f36c0c1c5 | |||
| 8b9297563c | |||
| 1cbe9eedb6 | |||
| 90fc5a1f12 | |||
| 41dfdf1ac0 | |||
| dd7de74aa6 | |||
| f11131f8b5 | |||
| 2e6e414a9e | |||
| c45d676477 | |||
| b8d8dddd5a | |||
| c45c22b1b2 | |||
| 3d57a9ccdc | |||
| cb04c21141 | |||
| f70272f638 | |||
| b4b71ded47 | |||
| 24e2b72b71 | |||
| 529791ce62 | |||
| b66945b9b8 | |||
| f3c5d77ad5 | |||
| e5e42bc483 | |||
| bdfbfa391f | |||
| 72acd9b483 | |||
| 9f528d23d4 | |||
| d937cc491d | |||
| 863f3aeb27 | |||
| 0fe078d25e | |||
| d9420c7224 | |||
| 9ff6baaf52 | |||
| 574d00bb13 | |||
| 8d60e5c342 | |||
| d9eb1a73af | |||
| 1a34ff8a67 | |||
| 14e7ba4818 | |||
| 52e9bcbfdb | |||
| 20ae3eae54 | |||
| 0fb145e667 | |||
| bcac43c812 | |||
| 929d9e0b3f | |||
| d5e560a987 | |||
| e4383d6167 | |||
| f32e176d6a | |||
| 3d5a4df9d0 | |||
| e47bfd2ca3 | |||
| f8f768873e | |||
| d043e1a05a | |||
| 837c0ddacc | |||
| 7c340695d6 | |||
| e87d4fbf69 | |||
| 39064197da | |||
| c4496e6cf2 | |||
| 27d09d1783 | |||
| a174ee419e | |||
| 79e6138ce2 | |||
| 5a64f69456 | |||
| 5c01dd97e8 | |||
| ecf74d91e2 | |||
| 62892ed8d7 | |||
| 7b399cc5e5 | |||
| fab5740778 | |||
| 30f2d756a7 | |||
| 0d745c64d8 | |||
| 56c51f0112 | |||
| 4adf85d7d4 | |||
| 7995ff1410 | |||
| d2f234757b | |||
| bf34437186 | |||
| 47f02eec96 | |||
| 06dd4d6e00 | |||
| fbceda7b66 | |||
| 9d6ce3065d | |||
| bb718acadf | |||
| 4cd00efe3b | |||
| 22b11e4b43 | |||
| 2a29c61041 | |||
| 34b041e9f0 | |||
| 917ed8cf84 | |||
| 85b0b8373b | |||
| f04844435f | |||
| 421a3284bc | |||
| d4883256f1 | |||
| 372074edba | |||
| 826f19e968 | |||
| c06cfcbb5a | |||
| ddf6192643 | |||
| 58189ed9a0 | |||
| 726c429772 | |||
| a159c13333 | |||
| 5df3a4eb98 | |||
| 244ed5e5e3 | |||
| 249e9a10a1 | |||
| d4dba373cb | |||
| 06f249a3b4 | |||
| 8809b3ba4d | |||
| 66e80f9dac | |||
| b486d72b8e | |||
| d9e26eba65 | |||
| a7419d0aba | |||
| d5e6e38c58 | |||
| a2598fd134 | |||
| 58165c3951 | |||
| dac72b078d | |||
| b5c2756261 | |||
| fa753239ad | |||
| 8af2ae973f | |||
| 23a8409e0c | |||
| 6e674b511a | |||
| 47f480c0dc | |||
| bfc4fe1a9a | |||
| 98473e9d4f | |||
| 13d3271ec0 | |||
| 6727ff6dbe | |||
| 04954918a5 | |||
| eb3a031964 | |||
| 243876e9b7 | |||
| 884fdc2fa8 | |||
| 410fe7293f | |||
| d7869a4d1e | |||
| aa71f88e1b | |||
| cfb8d224da | |||
| c14b498676 | |||
| ac5aed7a45 | |||
| abb86753c1 | |||
| f6cfe80bf5 | |||
| 2b91ba2411 | |||
| b4be132201 | |||
| 99fec40117 | |||
| 3df04c7e9a | |||
| e7833b42cd | |||
| c64b9c941a | |||
| 1d776c4cd0 | |||
| d1ba5fec89 | |||
| 9260aa3445 | |||
| 6010d5f24c | |||
| b08bfa203a | |||
| a06681913d | |||
| 424fdf4b52 | |||
| bcf42362e3 | |||
| a4d17cb585 | |||
| a9e106b17e | |||
| 044ad5100e | |||
| 3032e6fe59 | |||
| 4eba2ee92b |
19
.claude/settings.json.example
Normal file
19
.claude/settings.json.example
Normal file
@ -0,0 +1,19 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [],
|
||||
"deny": []
|
||||
},
|
||||
"env": {
|
||||
"__comment": "Environment variables for MCP servers. Override in .claude/settings.local.json with actual values.",
|
||||
"GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
},
|
||||
"enabledMcpjsonServers": [
|
||||
"context7",
|
||||
"sequential-thinking",
|
||||
"github",
|
||||
"fetch",
|
||||
"playwright",
|
||||
"ide"
|
||||
],
|
||||
"enableAllProjectMcpServers": true
|
||||
}
|
||||
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
npm add -g pnpm@10.15.0
|
||||
corepack enable
|
||||
cd web && pnpm install
|
||||
pipx install uv
|
||||
|
||||
|
||||
34
.github/actions/setup-uv/action.yml
vendored
34
.github/actions/setup-uv/action.yml
vendored
@ -1,34 +0,0 @@
|
||||
name: Setup UV and Python
|
||||
|
||||
inputs:
|
||||
python-version:
|
||||
description: Python version to use and the UV installed with
|
||||
required: true
|
||||
default: '3.12'
|
||||
uv-version:
|
||||
description: UV version to set up
|
||||
required: true
|
||||
default: '0.8.9'
|
||||
uv-lockfile:
|
||||
description: Path to the UV lockfile to restore cache from
|
||||
required: true
|
||||
default: ''
|
||||
enable-cache:
|
||||
required: true
|
||||
default: true
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Set up Python ${{ inputs.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
version: ${{ inputs.uv-version }}
|
||||
python-version: ${{ inputs.python-version }}
|
||||
enable-cache: ${{ inputs.enable-cache }}
|
||||
cache-dependency-glob: ${{ inputs.uv-lockfile }}
|
||||
13
.github/workflows/api-tests.yml
vendored
13
.github/workflows/api-tests.yml
vendored
@ -1,13 +1,7 @@
|
||||
name: Run Pytest
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- api/**
|
||||
- docker/**
|
||||
- .github/workflows/api-tests.yml
|
||||
workflow_call:
|
||||
|
||||
concurrency:
|
||||
group: api-tests-${{ github.head_ref || github.run_id }}
|
||||
@ -33,10 +27,11 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: ./.github/actions/setup-uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
uv-lockfile: api/uv.lock
|
||||
cache-dependency-glob: api/uv.lock
|
||||
|
||||
- name: Check UV lockfile
|
||||
run: uv lock --project api --check
|
||||
|
||||
9
.github/workflows/autofix.yml
vendored
9
.github/workflows/autofix.yml
vendored
@ -1,9 +1,7 @@
|
||||
name: autofix.ci
|
||||
on:
|
||||
workflow_call:
|
||||
pull_request:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
branches: ["main"]
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
@ -15,7 +13,9 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
# Use uv to ensure we have the same ruff version in CI and locally.
|
||||
- uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f
|
||||
- uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: "3.12"
|
||||
- run: |
|
||||
cd api
|
||||
uv sync --dev
|
||||
@ -26,6 +26,7 @@ jobs:
|
||||
- name: ast-grep
|
||||
run: |
|
||||
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
- name: mdformat
|
||||
run: |
|
||||
uvx mdformat .
|
||||
|
||||
20
.github/workflows/db-migration-test.yml
vendored
20
.github/workflows/db-migration-test.yml
vendored
@ -1,13 +1,7 @@
|
||||
name: DB Migration Test
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- plugins/beta
|
||||
paths:
|
||||
- api/migrations/**
|
||||
- .github/workflows/db-migration-test.yml
|
||||
workflow_call:
|
||||
|
||||
concurrency:
|
||||
group: db-migration-test-${{ github.ref }}
|
||||
@ -25,12 +19,20 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: ./.github/actions/setup-uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
uv-lockfile: api/uv.lock
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
cache-dependency-glob: api/uv.lock
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api
|
||||
- name: Ensure Offline migration are supported
|
||||
run: |
|
||||
# upgrade
|
||||
uv run --directory api flask db upgrade 'base:head' --sql
|
||||
# downgrade
|
||||
uv run --directory api flask db downgrade 'head:base' --sql
|
||||
|
||||
- name: Prepare middleware env
|
||||
run: |
|
||||
|
||||
78
.github/workflows/main-ci.yml
vendored
Normal file
78
.github/workflows/main-ci.yml
vendored
Normal file
@ -0,0 +1,78 @@
|
||||
name: Main CI Pipeline
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
push:
|
||||
branches: ["main"]
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
checks: write
|
||||
statuses: write
|
||||
|
||||
concurrency:
|
||||
group: main-ci-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
# Check which paths were changed to determine which tests to run
|
||||
check-changes:
|
||||
name: Check Changed Files
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
api-changed: ${{ steps.changes.outputs.api }}
|
||||
web-changed: ${{ steps.changes.outputs.web }}
|
||||
vdb-changed: ${{ steps.changes.outputs.vdb }}
|
||||
migration-changed: ${{ steps.changes.outputs.migration }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dorny/paths-filter@v3
|
||||
id: changes
|
||||
with:
|
||||
filters: |
|
||||
api:
|
||||
- 'api/**'
|
||||
- 'docker/**'
|
||||
- '.github/workflows/api-tests.yml'
|
||||
web:
|
||||
- 'web/**'
|
||||
vdb:
|
||||
- 'api/core/rag/datasource/**'
|
||||
- 'docker/**'
|
||||
- '.github/workflows/vdb-tests.yml'
|
||||
- 'api/uv.lock'
|
||||
- 'api/pyproject.toml'
|
||||
migration:
|
||||
- 'api/migrations/**'
|
||||
- '.github/workflows/db-migration-test.yml'
|
||||
|
||||
# Run tests in parallel
|
||||
api-tests:
|
||||
name: API Tests
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.api-changed == 'true'
|
||||
uses: ./.github/workflows/api-tests.yml
|
||||
|
||||
web-tests:
|
||||
name: Web Tests
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.web-changed == 'true'
|
||||
uses: ./.github/workflows/web-tests.yml
|
||||
|
||||
style-check:
|
||||
name: Style Check
|
||||
uses: ./.github/workflows/style.yml
|
||||
|
||||
vdb-tests:
|
||||
name: VDB Tests
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.vdb-changed == 'true'
|
||||
uses: ./.github/workflows/vdb-tests.yml
|
||||
|
||||
db-migration-test:
|
||||
name: DB Migration Test
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.migration-changed == 'true'
|
||||
uses: ./.github/workflows/db-migration-test.yml
|
||||
20
.github/workflows/style.yml
vendored
20
.github/workflows/style.yml
vendored
@ -1,9 +1,7 @@
|
||||
name: Style check
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
workflow_call:
|
||||
|
||||
concurrency:
|
||||
group: style-${{ github.head_ref || github.run_id }}
|
||||
@ -36,30 +34,20 @@ jobs:
|
||||
|
||||
- name: Setup UV and Python
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
uv-lockfile: api/uv.lock
|
||||
enable-cache: false
|
||||
python-version: "3.12"
|
||||
cache-dependency-glob: api/uv.lock
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Ruff check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: |
|
||||
uv run --directory api ruff --version
|
||||
uv run --directory api ruff check ./
|
||||
uv run --directory api ruff format --check ./
|
||||
|
||||
- name: Dotenv check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
|
||||
|
||||
- name: Lint hints
|
||||
if: failure()
|
||||
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."
|
||||
|
||||
web-style:
|
||||
name: Web Style
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
15
.github/workflows/vdb-tests.yml
vendored
15
.github/workflows/vdb-tests.yml
vendored
@ -1,15 +1,7 @@
|
||||
name: Run VDB Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- api/core/rag/datasource/**
|
||||
- docker/**
|
||||
- .github/workflows/vdb-tests.yml
|
||||
- api/uv.lock
|
||||
- api/pyproject.toml
|
||||
workflow_call:
|
||||
|
||||
concurrency:
|
||||
group: vdb-tests-${{ github.head_ref || github.run_id }}
|
||||
@ -39,10 +31,11 @@ jobs:
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: ./.github/actions/setup-uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
uv-lockfile: api/uv.lock
|
||||
cache-dependency-glob: api/uv.lock
|
||||
|
||||
- name: Check UV lockfile
|
||||
run: uv lock --project api --check
|
||||
|
||||
6
.github/workflows/web-tests.yml
vendored
6
.github/workflows/web-tests.yml
vendored
@ -1,11 +1,7 @@
|
||||
name: Web Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- web/**
|
||||
workflow_call:
|
||||
|
||||
concurrency:
|
||||
group: web-tests-${{ github.head_ref || github.run_id }}
|
||||
|
||||
34
.mcp.json
Normal file
34
.mcp.json
Normal file
@ -0,0 +1,34 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"context7": {
|
||||
"type": "http",
|
||||
"url": "https://mcp.context7.com/mcp"
|
||||
},
|
||||
"sequential-thinking": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
|
||||
"env": {}
|
||||
},
|
||||
"github": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-github"],
|
||||
"env": {
|
||||
"GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_PERSONAL_ACCESS_TOKEN}"
|
||||
}
|
||||
},
|
||||
"fetch": {
|
||||
"type": "stdio",
|
||||
"command": "uvx",
|
||||
"args": ["mcp-server-fetch"],
|
||||
"env": {}
|
||||
},
|
||||
"playwright": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@playwright/mcp@latest"],
|
||||
"env": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -86,3 +86,4 @@ pnpm test # Run Jest tests
|
||||
## Project-Specific Conventions
|
||||
|
||||
- All async tasks use Celery with Redis as broker
|
||||
- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations.
|
||||
|
||||
@ -180,7 +180,7 @@ docker compose up -d
|
||||
|
||||
## Contributing
|
||||
|
||||
对于那些想要贡献代码的人,请参阅我们的[贡献指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。
|
||||
对于那些想要贡献代码的人,请参阅我们的[贡献指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_CN.md)。
|
||||
同时,请考虑通过社交媒体、活动和会议来支持 Dify 的分享。
|
||||
|
||||
> 我们正在寻找贡献者来帮助将 Dify 翻译成除了中文和英文之外的其他语言。如果您有兴趣帮助,请参阅我们的[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)获取更多信息,并在我们的[Discord 社区服务器](https://discord.gg/8Tpq4AcN9c)的`global-users`频道中留言。
|
||||
|
||||
@ -173,7 +173,7 @@ Stellen Sie Dify mit einem Klick in AKS bereit, indem Sie [Azure Devops Pipeline
|
||||
|
||||
## Contributing
|
||||
|
||||
Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren.
|
||||
Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_DE.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren.
|
||||
|
||||
> Wir suchen Mitwirkende, die dabei helfen, Dify in weitere Sprachen zu übersetzen – außer Mandarin oder Englisch. Wenn Sie Interesse an einer Mitarbeit haben, lesen Sie bitte die [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) für weitere Informationen und hinterlassen Sie einen Kommentar im `global-users`-Kanal unseres [Discord Community Servers](https://discord.gg/8Tpq4AcN9c).
|
||||
|
||||
|
||||
@ -170,7 +170,7 @@ Implementa Dify en AKS con un clic usando [Azure Devops Pipeline Helm Chart by @
|
||||
|
||||
## Contribuir
|
||||
|
||||
Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_ES.md).
|
||||
Al mismo tiempo, considera apoyar a Dify compartiéndolo en redes sociales y en eventos y conferencias.
|
||||
|
||||
> Estamos buscando colaboradores para ayudar con la traducción de Dify a idiomas que no sean el mandarín o el inglés. Si estás interesado en ayudar, consulta el [README de i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) para obtener más información y déjanos un comentario en el canal `global-users` de nuestro [Servidor de Comunidad en Discord](https://discord.gg/8Tpq4AcN9c).
|
||||
|
||||
@ -168,7 +168,7 @@ Déployez Dify sur AKS en un clic en utilisant [Azure Devops Pipeline Helm Chart
|
||||
|
||||
## Contribuer
|
||||
|
||||
Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_FR.md).
|
||||
Dans le même temps, veuillez envisager de soutenir Dify en le partageant sur les réseaux sociaux et lors d'événements et de conférences.
|
||||
|
||||
> Nous recherchons des contributeurs pour aider à traduire Dify dans des langues autres que le mandarin ou l'anglais. Si vous êtes intéressé à aider, veuillez consulter le [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) pour plus d'informations, et laissez-nous un commentaire dans le canal `global-users` de notre [Serveur communautaire Discord](https://discord.gg/8Tpq4AcN9c).
|
||||
|
||||
@ -169,7 +169,7 @@ docker compose up -d
|
||||
|
||||
## 貢献
|
||||
|
||||
コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)を参照してください。
|
||||
コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_JA.md)を参照してください。
|
||||
同時に、DifyをSNSやイベント、カンファレンスで共有してサポートしていただけると幸いです。
|
||||
|
||||
> Difyを英語または中国語以外の言語に翻訳してくれる貢献者を募集しています。興味がある場合は、詳細については[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)を参照してください。また、[Discordコミュニティサーバー](https://discord.gg/8Tpq4AcN9c)の`global-users`チャンネルにコメントを残してください。
|
||||
|
||||
@ -162,7 +162,7 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
|
||||
|
||||
## 기여
|
||||
|
||||
코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요.
|
||||
코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_KR.md)를 참조하세요.
|
||||
동시에 Dify를 소셜 미디어와 행사 및 컨퍼런스에 공유하여 지원하는 것을 고려해 주시기 바랍니다.
|
||||
|
||||
> 우리는 Dify를 중국어나 영어 이외의 언어로 번역하는 데 도움을 줄 수 있는 기여자를 찾고 있습니다. 도움을 주고 싶으시다면 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)에서 더 많은 정보를 확인하시고 [Discord 커뮤니티 서버](https://discord.gg/8Tpq4AcN9c)의 `global-users` 채널에 댓글을 남겨주세요.
|
||||
|
||||
@ -168,7 +168,7 @@ Implante o Dify no AKS com um clique usando [Azure Devops Pipeline Helm Chart by
|
||||
|
||||
## Contribuindo
|
||||
|
||||
Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_PT.md).
|
||||
Ao mesmo tempo, considere apoiar o Dify compartilhando-o nas redes sociais e em eventos e conferências.
|
||||
|
||||
> Estamos buscando contribuidores para ajudar na tradução do Dify para idiomas além de Mandarim e Inglês. Se você tiver interesse em ajudar, consulte o [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) para mais informações e deixe-nos um comentário no canal `global-users` em nosso [Servidor da Comunidade no Discord](https://discord.gg/8Tpq4AcN9c).
|
||||
|
||||
@ -161,7 +161,7 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter
|
||||
|
||||
## Katkıda Bulunma
|
||||
|
||||
Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakabilirsiniz.
|
||||
Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_TR.md) bakabilirsiniz.
|
||||
Aynı zamanda, lütfen Dify'ı sosyal medyada, etkinliklerde ve konferanslarda paylaşarak desteklemeyi düşünün.
|
||||
|
||||
> Dify'ı Mandarin veya İngilizce dışındaki dillere çevirmemize yardımcı olacak katkıda bulunanlara ihtiyacımız var. Yardımcı olmakla ilgileniyorsanız, lütfen daha fazla bilgi için [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) dosyasına bakın ve [Discord Topluluk Sunucumuzdaki](https://discord.gg/8Tpq4AcN9c) `global-users` kanalında bize bir yorum bırakın.
|
||||
|
||||
@ -173,7 +173,7 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify
|
||||
|
||||
## 貢獻
|
||||
|
||||
對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。
|
||||
對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_TW.md)。
|
||||
同時,也請考慮透過在社群媒體和各種活動與會議上分享 Dify 來支持我們。
|
||||
|
||||
> 我們正在尋找貢獻者協助將 Dify 翻譯成中文和英文以外的語言。如果您有興趣幫忙,請查看 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) 獲取更多資訊,並在我們的 [Discord 社群伺服器](https://discord.gg/8Tpq4AcN9c) 的 `global-users` 頻道留言給我們。
|
||||
|
||||
@ -162,7 +162,7 @@ Triển khai Dify lên AKS chỉ với một cú nhấp chuột bằng [Azure De
|
||||
|
||||
## Đóng góp
|
||||
|
||||
Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi.
|
||||
Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_VI.md) của chúng tôi.
|
||||
Đồng thời, vui lòng xem xét hỗ trợ Dify bằng cách chia sẻ nó trên mạng xã hội và tại các sự kiện và hội nghị.
|
||||
|
||||
> Chúng tôi đang tìm kiếm người đóng góp để giúp dịch Dify sang các ngôn ngữ khác ngoài tiếng Trung hoặc tiếng Anh. Nếu bạn quan tâm đến việc giúp đỡ, vui lòng xem [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) để biết thêm thông tin và để lại bình luận cho chúng tôi trong kênh `global-users` của [Máy chủ Cộng đồng Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi.
|
||||
|
||||
@ -156,7 +156,7 @@ WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||
|
||||
# Vector database configuration
|
||||
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
|
||||
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `pinecone`.
|
||||
VECTOR_STORE=weaviate
|
||||
# Prefix used to create collection name in vector database
|
||||
VECTOR_INDEX_NAME_PREFIX=Vector_index
|
||||
@ -361,6 +361,17 @@ PROMPT_GENERATION_MAX_TOKENS=512
|
||||
CODE_GENERATION_MAX_TOKENS=1024
|
||||
PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false
|
||||
|
||||
|
||||
# Pinecone configuration, only available when VECTOR_STORE is `pinecone`
|
||||
PINECONE_API_KEY=your-pinecone-api-key
|
||||
PINECONE_ENVIRONMENT=your-pinecone-environment
|
||||
PINECONE_INDEX_NAME=dify-index
|
||||
PINECONE_CLIENT_TIMEOUT=30
|
||||
PINECONE_BATCH_SIZE=100
|
||||
PINECONE_METRIC=cosine
|
||||
PINECONE_PODS=1
|
||||
PINECONE_POD_TYPE=s1
|
||||
|
||||
# Mail configuration, support: resend, smtp, sendgrid
|
||||
MAIL_TYPE=
|
||||
# If using SendGrid, use the 'from' field for authentication if necessary.
|
||||
@ -564,3 +575,7 @@ QUEUE_MONITOR_THRESHOLD=200
|
||||
QUEUE_MONITOR_ALERT_EMAILS=
|
||||
# Monitor interval in minutes, default is 30 minutes
|
||||
QUEUE_MONITOR_INTERVAL=30
|
||||
|
||||
# Swagger UI configuration
|
||||
SWAGGER_UI_ENABLED=true
|
||||
SWAGGER_UI_PATH=/swagger-ui.html
|
||||
|
||||
@ -43,6 +43,7 @@ select = [
|
||||
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
|
||||
"S311", # suspicious-non-cryptographic-random-usage
|
||||
"G001", # don't use str format to logging messages
|
||||
"G003", # don't use + in logging messages
|
||||
"G004", # don't use f-strings to format logging messages
|
||||
]
|
||||
|
||||
|
||||
@ -97,8 +97,16 @@ uv run celery -A app.celery beat
|
||||
uv sync --dev
|
||||
```
|
||||
|
||||
1. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`
|
||||
1. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`, more can check [Claude.md](../CLAUDE.md)
|
||||
|
||||
```bash
|
||||
uv run -P api bash dev/pytest/pytest_all_tests.sh
|
||||
uv run pytest # Run all tests
|
||||
uv run pytest tests/unit_tests/ # Unit tests only
|
||||
uv run pytest tests/integration_tests/ # Integration tests
|
||||
|
||||
# Code quality
|
||||
../dev/reformat # Run all formatters and linters
|
||||
uv run ruff check --fix ./ # Fix linting issues
|
||||
uv run ruff format ./ # Format code
|
||||
uv run mypy . # Type checking
|
||||
```
|
||||
|
||||
@ -5,6 +5,8 @@ from configs import dify_config
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
from dify_app import DifyApp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Application Factory Function
|
||||
@ -32,7 +34,7 @@ def create_app() -> DifyApp:
|
||||
initialize_extensions(app)
|
||||
end_time = time.perf_counter()
|
||||
if dify_config.DEBUG:
|
||||
logging.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
|
||||
logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
|
||||
return app
|
||||
|
||||
|
||||
@ -93,14 +95,14 @@ def initialize_extensions(app: DifyApp):
|
||||
is_enabled = ext.is_enabled() if hasattr(ext, "is_enabled") else True
|
||||
if not is_enabled:
|
||||
if dify_config.DEBUG:
|
||||
logging.info("Skipped %s", short_name)
|
||||
logger.info("Skipped %s", short_name)
|
||||
continue
|
||||
|
||||
start_time = time.perf_counter()
|
||||
ext.init_app(app)
|
||||
end_time = time.perf_counter()
|
||||
if dify_config.DEBUG:
|
||||
logging.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2))
|
||||
logger.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2))
|
||||
|
||||
|
||||
def create_migrations_app():
|
||||
|
||||
@ -38,6 +38,8 @@ from services.plugin.data_migration import PluginDataMigration
|
||||
from services.plugin.plugin_migration import PluginMigration
|
||||
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@click.command("reset-password", help="Reset the account password.")
|
||||
@click.option("--email", prompt=True, help="Account email to reset password for")
|
||||
@ -685,7 +687,7 @@ def upgrade_db():
|
||||
click.echo(click.style("Database migration successful!", fg="green"))
|
||||
|
||||
except Exception:
|
||||
logging.exception("Failed to execute database migration")
|
||||
logger.exception("Failed to execute database migration")
|
||||
finally:
|
||||
lock.release()
|
||||
else:
|
||||
@ -733,7 +735,7 @@ where sites.id is null limit 1000"""
|
||||
except Exception:
|
||||
failed_app_ids.append(app_id)
|
||||
click.echo(click.style(f"Failed to fix missing site for app {app_id}", fg="red"))
|
||||
logging.exception("Failed to fix app related site missing issue, app_id: %s", app_id)
|
||||
logger.exception("Failed to fix app related site missing issue, app_id: %s", app_id)
|
||||
continue
|
||||
|
||||
if not processed_count:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Annotated, Literal, Optional
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import (
|
||||
AliasChoices,
|
||||
@ -976,6 +976,18 @@ class WorkflowLogConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class SwaggerUIConfig(BaseSettings):
|
||||
SWAGGER_UI_ENABLED: bool = Field(
|
||||
description="Whether to enable Swagger UI in api module",
|
||||
default=True,
|
||||
)
|
||||
|
||||
SWAGGER_UI_PATH: str = Field(
|
||||
description="Swagger UI page path in api module",
|
||||
default="/swagger-ui.html",
|
||||
)
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
@ -1007,6 +1019,7 @@ class FeatureConfig(
|
||||
WorkspaceConfig,
|
||||
LoginConfig,
|
||||
AccountConfig,
|
||||
SwaggerUIConfig,
|
||||
# hosted services config
|
||||
HostedServiceConfig,
|
||||
CeleryBeatConfig,
|
||||
|
||||
@ -35,6 +35,7 @@ from .vdb.opensearch_config import OpenSearchConfig
|
||||
from .vdb.oracle_config import OracleConfig
|
||||
from .vdb.pgvector_config import PGVectorConfig
|
||||
from .vdb.pgvectors_config import PGVectoRSConfig
|
||||
from .vdb.pinecone_config import PineconeConfig
|
||||
from .vdb.qdrant_config import QdrantConfig
|
||||
from .vdb.relyt_config import RelytConfig
|
||||
from .vdb.tablestore_config import TableStoreConfig
|
||||
@ -215,6 +216,7 @@ class DatabaseConfig(BaseSettings):
|
||||
"pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
|
||||
"connect_args": connect_args,
|
||||
"pool_use_lifo": self.SQLALCHEMY_POOL_USE_LIFO,
|
||||
"pool_reset_on_return": None,
|
||||
}
|
||||
|
||||
|
||||
@ -330,6 +332,7 @@ class MiddlewareConfig(
|
||||
PGVectorConfig,
|
||||
VastbaseVectorConfig,
|
||||
PGVectoRSConfig,
|
||||
PineconeConfig,
|
||||
QdrantConfig,
|
||||
RelytConfig,
|
||||
TencentVectorDBConfig,
|
||||
|
||||
41
api/configs/middleware/vdb/pinecone_config.py
Normal file
41
api/configs/middleware/vdb/pinecone_config.py
Normal file
@ -0,0 +1,41 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class PineconeConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for Pinecone vector database
|
||||
"""
|
||||
|
||||
PINECONE_API_KEY: Optional[str] = Field(
|
||||
description="API key for authenticating with Pinecone service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
PINECONE_ENVIRONMENT: Optional[str] = Field(
|
||||
description="Pinecone environment (e.g., 'us-west1-gcp', 'us-east-1-aws')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
PINECONE_INDEX_NAME: Optional[str] = Field(
|
||||
description="Default Pinecone index name",
|
||||
default=None,
|
||||
)
|
||||
|
||||
PINECONE_CLIENT_TIMEOUT: PositiveInt = Field(
|
||||
description="Timeout in seconds for Pinecone client operations (default is 30 seconds)",
|
||||
default=30,
|
||||
)
|
||||
|
||||
PINECONE_BATCH_SIZE: PositiveInt = Field(
|
||||
description="Batch size for Pinecone operations (default is 100)",
|
||||
default=100,
|
||||
)
|
||||
|
||||
PINECONE_METRIC: str = Field(
|
||||
description="Distance metric for Pinecone index (cosine, euclidean, dotproduct)",
|
||||
default="cosine",
|
||||
)
|
||||
|
||||
@ -70,7 +70,7 @@ from .app import (
|
||||
)
|
||||
|
||||
# Import auth controllers
|
||||
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth
|
||||
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server
|
||||
|
||||
# Import billing controllers
|
||||
from .billing import billing, compliance
|
||||
@ -84,7 +84,6 @@ from .datasets import (
|
||||
external,
|
||||
hit_testing,
|
||||
metadata,
|
||||
upload_file,
|
||||
website,
|
||||
)
|
||||
|
||||
|
||||
@ -31,6 +31,8 @@ from services.errors.audio import (
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatMessageAudioApi(Resource):
|
||||
@setup_required
|
||||
@ -49,7 +51,7 @@ class ChatMessageAudioApi(Resource):
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
@ -70,7 +72,7 @@ class ChatMessageAudioApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("Failed to handle post request to ChatMessageAudioApi")
|
||||
logger.exception("Failed to handle post request to ChatMessageAudioApi")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -97,7 +99,7 @@ class ChatMessageTextApi(Resource):
|
||||
)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
@ -118,7 +120,7 @@ class ChatMessageTextApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("Failed to handle post request to ChatMessageTextApi")
|
||||
logger.exception("Failed to handle post request to ChatMessageTextApi")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -160,7 +162,7 @@ class TextModesApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("Failed to handle get request to TextModesApi")
|
||||
logger.exception("Failed to handle get request to TextModesApi")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
|
||||
@ -34,6 +34,8 @@ from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# define completion message api for user
|
||||
class CompletionMessageApi(Resource):
|
||||
@ -67,7 +69,7 @@ class CompletionMessageApi(Resource):
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -80,7 +82,7 @@ class CompletionMessageApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -134,7 +136,7 @@ class ChatMessageApi(Resource):
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -149,7 +151,7 @@ class ChatMessageApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ import logging
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from sqlalchemy import exists, select
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
from controllers.console import api
|
||||
@ -33,6 +34,8 @@ from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatMessageListApi(Resource):
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
@ -92,21 +95,22 @@ class ChatMessageListApi(Resource):
|
||||
.all()
|
||||
)
|
||||
|
||||
has_more = False
|
||||
# Initialize has_more based on whether we have a full page
|
||||
if len(history_messages) == args["limit"]:
|
||||
current_page_first_message = history_messages[-1]
|
||||
rest_count = (
|
||||
db.session.query(Message)
|
||||
.where(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < current_page_first_message.created_at,
|
||||
Message.id != current_page_first_message.id,
|
||||
# Check if there are more messages before the current page
|
||||
has_more = db.session.scalar(
|
||||
select(
|
||||
exists().where(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < current_page_first_message.created_at,
|
||||
Message.id != current_page_first_message.id,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
else:
|
||||
# If we don't have a full page, there are no more messages
|
||||
has_more = False
|
||||
|
||||
history_messages = list(reversed(history_messages))
|
||||
|
||||
@ -126,7 +130,7 @@ class MessageFeedbackApi(Resource):
|
||||
|
||||
message_id = str(args["message_id"])
|
||||
|
||||
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
@ -215,7 +219,7 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||
raise AppSuggestedQuestionsAfterAnswerDisabledError()
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {"data": questions}
|
||||
|
||||
@ -72,6 +72,7 @@ class DraftWorkflowApi(Resource):
|
||||
Get draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
@ -94,6 +95,7 @@ class DraftWorkflowApi(Resource):
|
||||
Sync draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
@ -171,6 +173,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||
Run draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
@ -205,7 +208,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -218,13 +221,12 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||
"""
|
||||
Run draft workflow iteration node
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
@ -242,7 +244,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -256,11 +258,10 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
||||
Run draft workflow iteration node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
@ -279,7 +280,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -292,12 +293,12 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
||||
"""
|
||||
Run draft workflow loop node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
@ -316,7 +317,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -329,12 +330,12 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
||||
"""
|
||||
Run draft workflow loop node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
@ -353,7 +354,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -366,12 +367,12 @@ class DraftWorkflowRunApi(Resource):
|
||||
"""
|
||||
Run draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
@ -405,6 +406,9 @@ class WorkflowTaskStopApi(Resource):
|
||||
"""
|
||||
Stop workflow task
|
||||
"""
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
@ -424,12 +428,12 @@ class DraftWorkflowNodeRunApi(Resource):
|
||||
"""
|
||||
Run draft workflow node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
@ -472,6 +476,9 @@ class PublishedWorkflowApi(Resource):
|
||||
"""
|
||||
Get published workflow
|
||||
"""
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
@ -491,13 +498,12 @@ class PublishedWorkflowApi(Resource):
|
||||
"""
|
||||
Publish workflow
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
|
||||
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
|
||||
@ -541,6 +547,9 @@ class DefaultBlockConfigsApi(Resource):
|
||||
"""
|
||||
Get default block config
|
||||
"""
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
@ -559,13 +568,12 @@ class DefaultBlockConfigApi(Resource):
|
||||
"""
|
||||
Get default block config
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("q", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
@ -595,13 +603,12 @@ class ConvertToWorkflowApi(Resource):
|
||||
Convert expert mode of chatbot app to workflow mode
|
||||
Convert Completion App to Workflow App
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
if request.data:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
@ -645,6 +652,9 @@ class PublishedAllWorkflowApi(Resource):
|
||||
"""
|
||||
Get published workflows
|
||||
"""
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
@ -693,13 +703,12 @@ class WorkflowByIdApi(Resource):
|
||||
"""
|
||||
Update workflow attributes
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# Check permission
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("marked_name", type=str, required=False, location="json")
|
||||
parser.add_argument("marked_comment", type=str, required=False, location="json")
|
||||
@ -750,13 +759,12 @@ class WorkflowByIdApi(Resource):
|
||||
"""
|
||||
Delete workflow
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# Check permission
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
# Create a session and manage the transaction
|
||||
|
||||
@ -21,6 +21,7 @@ from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models import App, AppMode, db
|
||||
from models.account import Account
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
from services.workflow_service import WorkflowService
|
||||
@ -135,6 +136,7 @@ def _api_prerequisite(f):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def wrapper(*args, **kwargs):
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
|
||||
@ -6,9 +6,11 @@ from controllers.console.app.error import AppNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models import App, AppMode
|
||||
from models.account import Account
|
||||
|
||||
|
||||
def _load_app_model(app_id: str) -> Optional[App]:
|
||||
assert isinstance(current_user, Account)
|
||||
app_model = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
|
||||
@ -13,6 +13,8 @@ from libs.oauth_data_source import NotionOAuth
|
||||
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_oauth_providers():
|
||||
with current_app.app_context():
|
||||
@ -80,7 +82,7 @@ class OAuthDataSourceBinding(Resource):
|
||||
try:
|
||||
oauth_provider.get_access_token(code)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
logging.exception(
|
||||
logger.exception(
|
||||
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
|
||||
)
|
||||
return {"error": "OAuth data source process failed"}, 400
|
||||
@ -103,7 +105,7 @@ class OAuthDataSourceSync(Resource):
|
||||
try:
|
||||
oauth_provider.sync_data_source(binding_id)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
logging.exception(
|
||||
logger.exception(
|
||||
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
|
||||
)
|
||||
return {"error": "OAuth data source process failed"}, 400
|
||||
|
||||
@ -55,6 +55,12 @@ class EmailOrPasswordMismatchError(BaseHTTPException):
|
||||
code = 400
|
||||
|
||||
|
||||
class AuthenticationFailedError(BaseHTTPException):
|
||||
error_code = "authentication_failed"
|
||||
description = "Invalid email or password."
|
||||
code = 401
|
||||
|
||||
|
||||
class EmailPasswordLoginLimitError(BaseHTTPException):
|
||||
error_code = "email_code_login_limit"
|
||||
description = "Too many incorrect password attempts. Please try again later."
|
||||
|
||||
@ -9,8 +9,8 @@ from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
EmailCodeError,
|
||||
EmailOrPasswordMismatchError,
|
||||
EmailPasswordLoginLimitError,
|
||||
InvalidEmailError,
|
||||
InvalidTokenError,
|
||||
@ -79,7 +79,7 @@ class LoginApi(Resource):
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError:
|
||||
AccountService.add_login_error_rate_limit(args["email"])
|
||||
raise EmailOrPasswordMismatchError()
|
||||
raise AuthenticationFailedError()
|
||||
except services.errors.account.AccountNotFoundError:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_reset_password_email(email=args["email"], language=language)
|
||||
@ -132,6 +132,7 @@ class ResetPasswordSendEmailApi(Resource):
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
except AccountRegisterError as are:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if account is None:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_reset_password_email(email=args["email"], language=language)
|
||||
@ -221,7 +222,7 @@ class EmailCodeLoginApi(Resource):
|
||||
email=user_email, name=user_email, interface_language=languages[0]
|
||||
)
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
return NotAllowedCreateWorkspace()
|
||||
raise NotAllowedCreateWorkspace()
|
||||
except AccountRegisterError as are:
|
||||
raise AccountInFreezeError()
|
||||
except WorkspacesLimitExceededError:
|
||||
|
||||
@ -24,6 +24,8 @@ from services.feature_service import FeatureService
|
||||
|
||||
from .. import api
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_oauth_providers():
|
||||
with current_app.app_context():
|
||||
@ -80,7 +82,7 @@ class OAuthCallback(Resource):
|
||||
user_info = oauth_provider.get_user_info(token)
|
||||
except requests.exceptions.RequestException as e:
|
||||
error_text = e.response.text if e.response else str(e)
|
||||
logging.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
|
||||
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
|
||||
return {"error": "OAuth process failed"}, 400
|
||||
|
||||
if invite_token and RegisterService.is_valid_invite_token(invite_token):
|
||||
|
||||
187
api/controllers/console/auth/oauth_server.py
Normal file
187
api/controllers/console/auth/oauth_server.py
Normal file
@ -0,0 +1,187 @@
|
||||
from functools import wraps
|
||||
from typing import cast
|
||||
|
||||
import flask_login
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.model import OAuthProviderApp
|
||||
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
|
||||
|
||||
from .. import api
|
||||
|
||||
|
||||
def oauth_server_client_id_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_id", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
client_id = parsed_args.get("client_id")
|
||||
if not client_id:
|
||||
raise BadRequest("client_id is required")
|
||||
|
||||
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
|
||||
if not oauth_provider_app:
|
||||
raise NotFound("client_id is invalid")
|
||||
|
||||
kwargs["oauth_provider_app"] = oauth_provider_app
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def oauth_server_access_token_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
oauth_provider_app = kwargs.get("oauth_provider_app")
|
||||
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
|
||||
raise BadRequest("Invalid oauth_provider_app")
|
||||
|
||||
authorization_header = request.headers.get("Authorization")
|
||||
if not authorization_header:
|
||||
raise BadRequest("Authorization header is required")
|
||||
|
||||
parts = authorization_header.strip().split(" ")
|
||||
if len(parts) != 2:
|
||||
raise BadRequest("Invalid Authorization header format")
|
||||
|
||||
token_type = parts[0].strip()
|
||||
if token_type.lower() != "bearer":
|
||||
raise BadRequest("token_type is invalid")
|
||||
|
||||
access_token = parts[1].strip()
|
||||
if not access_token:
|
||||
raise BadRequest("access_token is required")
|
||||
|
||||
account = OAuthServerService.validate_oauth_access_token(oauth_provider_app.client_id, access_token)
|
||||
if not account:
|
||||
raise BadRequest("access_token or client_id is invalid")
|
||||
|
||||
kwargs["account"] = account
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
class OAuthServerAppApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("redirect_uri", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
redirect_uri = parsed_args.get("redirect_uri")
|
||||
|
||||
# check if redirect_uri is valid
|
||||
if redirect_uri not in oauth_provider_app.redirect_uris:
|
||||
raise BadRequest("redirect_uri is invalid")
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"app_icon": oauth_provider_app.app_icon,
|
||||
"app_label": oauth_provider_app.app_label,
|
||||
"scope": oauth_provider_app.scope,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class OAuthServerUserAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
account = cast(Account, flask_login.current_user)
|
||||
user_account_id = account.id
|
||||
|
||||
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"code": code,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class OAuthServerUserTokenApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("grant_type", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=False, location="json")
|
||||
parser.add_argument("client_secret", type=str, required=False, location="json")
|
||||
parser.add_argument("redirect_uri", type=str, required=False, location="json")
|
||||
parser.add_argument("refresh_token", type=str, required=False, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
|
||||
try:
|
||||
grant_type = OAuthGrantType(parsed_args["grant_type"])
|
||||
except ValueError:
|
||||
raise BadRequest("invalid grant_type")
|
||||
|
||||
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
|
||||
if not parsed_args["code"]:
|
||||
raise BadRequest("code is required")
|
||||
|
||||
if parsed_args["client_secret"] != oauth_provider_app.client_secret:
|
||||
raise BadRequest("client_secret is invalid")
|
||||
|
||||
if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris:
|
||||
raise BadRequest("redirect_uri is invalid")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
|
||||
if not parsed_args["refresh_token"]:
|
||||
raise BadRequest("refresh_token is required")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class OAuthServerUserAccountApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
@oauth_server_access_token_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp, account: Account):
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"name": account.name,
|
||||
"email": account.email,
|
||||
"avatar": account.avatar,
|
||||
"interface_language": account.interface_language,
|
||||
"timezone": account.timezone,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
api.add_resource(OAuthServerAppApi, "/oauth/provider")
|
||||
api.add_resource(OAuthServerUserAuthorizeApi, "/oauth/provider/authorize")
|
||||
api.add_resource(OAuthServerUserTokenApi, "/oauth/provider/token")
|
||||
api.add_resource(OAuthServerUserAccountApi, "/oauth/provider/account")
|
||||
@ -553,7 +553,7 @@ class DatasetIndexingStatusApi(Resource):
|
||||
}
|
||||
documents_status.append(marshal(document_dict, document_status_fields))
|
||||
data = {"data": documents_status}
|
||||
return data
|
||||
return data, 200
|
||||
|
||||
|
||||
class DatasetApiKeyApi(Resource):
|
||||
@ -660,6 +660,7 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
| VectorType.BAIDU
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
| VectorType.PINECONE
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
@ -711,6 +712,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
| VectorType.BAIDU
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
| VectorType.PINECONE
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
|
||||
@ -54,6 +54,8 @@ from models import Dataset, DatasetProcessRule, Document, DocumentSegment, Uploa
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentResource(Resource):
|
||||
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
||||
@ -468,25 +470,11 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
|
||||
data_process_rule = documents[0].dataset_process_rule
|
||||
data_process_rule_dict = data_process_rule.to_dict()
|
||||
info_list = []
|
||||
extract_settings = []
|
||||
for document in documents:
|
||||
if document.indexing_status in {"completed", "error"}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
data_source_info = document.data_source_info_dict
|
||||
# format document files info
|
||||
if data_source_info and "upload_file_id" in data_source_info:
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
info_list.append(file_id)
|
||||
# format document notion info
|
||||
elif (
|
||||
data_source_info and "notion_workspace_id" in data_source_info and "notion_page_id" in data_source_info
|
||||
):
|
||||
pages = []
|
||||
page = {"page_id": data_source_info["notion_page_id"], "type": data_source_info["type"]}
|
||||
pages.append(page)
|
||||
notion_info = {"workspace_id": data_source_info["notion_workspace_id"], "pages": pages}
|
||||
info_list.append(notion_info)
|
||||
|
||||
if document.data_source_type == "upload_file":
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
@ -966,7 +954,7 @@ class DocumentRetryApi(DocumentResource):
|
||||
raise DocumentAlreadyFinishedError()
|
||||
retry_documents.append(document)
|
||||
except Exception:
|
||||
logging.exception("Failed to retry document, document id: %s", document_id)
|
||||
logger.exception("Failed to retry document, document id: %s", document_id)
|
||||
continue
|
||||
# retry document
|
||||
DocumentService.retry_document(dataset_id, retry_documents)
|
||||
|
||||
@ -23,6 +23,8 @@ from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
@ -81,5 +83,5 @@ class DatasetsHitTestingBase:
|
||||
except ValueError as e:
|
||||
raise ValueError(str(e))
|
||||
except Exception as e:
|
||||
logging.exception("Hit testing failed.")
|
||||
logger.exception("Hit testing failed.")
|
||||
raise InternalServerError(str(e))
|
||||
|
||||
@ -1,62 +0,0 @@
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
|
||||
class UploadFileApi(Resource):
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
"""Get upload file."""
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = (
|
||||
db.session.query(Dataset)
|
||||
.filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == dataset_id)
|
||||
.first()
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check upload file
|
||||
if document.data_source_type != "upload_file":
|
||||
raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.")
|
||||
data_source_info = document.data_source_info_dict
|
||||
if data_source_info and "upload_file_id" in data_source_info:
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
if not upload_file:
|
||||
raise NotFound("UploadFile not found.")
|
||||
else:
|
||||
raise ValueError("Upload file id not found in document data source info.")
|
||||
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"url": url,
|
||||
"download_url": f"{url}&as_attachment=true",
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_by": upload_file.created_by,
|
||||
"created_at": upload_file.created_at.timestamp(),
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(UploadFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file")
|
||||
@ -26,6 +26,8 @@ from services.errors.audio import (
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatAudioApi(InstalledAppResource):
|
||||
def post(self, installed_app):
|
||||
@ -38,7 +40,7 @@ class ChatAudioApi(InstalledAppResource):
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
@ -59,7 +61,7 @@ class ChatAudioApi(InstalledAppResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -83,7 +85,7 @@ class ChatTextApi(InstalledAppResource):
|
||||
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
@ -104,5 +106,5 @@ class ChatTextApi(InstalledAppResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
@ -32,6 +32,8 @@ from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# define completion api for user
|
||||
class CompletionApi(InstalledAppResource):
|
||||
@ -65,7 +67,7 @@ class CompletionApi(InstalledAppResource):
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -78,7 +80,7 @@ class CompletionApi(InstalledAppResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -125,7 +127,7 @@ class ChatApi(InstalledAppResource):
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -140,7 +142,7 @@ class ChatApi(InstalledAppResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
|
||||
@ -35,6 +35,8 @@ from services.errors.message import (
|
||||
)
|
||||
from services.message_service import MessageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageListApi(InstalledAppResource):
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
@ -126,7 +128,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -158,7 +160,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {"data": questions}
|
||||
|
||||
@ -43,7 +43,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert current_user is not None
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
@ -63,7 +63,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -76,6 +76,7 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
assert current_user is not None
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
|
||||
|
||||
@ -9,6 +9,8 @@ from configs import dify_config
|
||||
|
||||
from . import api
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VersionApi(Resource):
|
||||
def get(self):
|
||||
@ -34,7 +36,7 @@ class VersionApi(Resource):
|
||||
try:
|
||||
response = requests.get(check_update_url, {"current_version": args.get("current_version")}, timeout=(3, 10))
|
||||
except Exception as error:
|
||||
logging.warning("Check update version error: %s.", str(error))
|
||||
logger.warning("Check update version error: %s.", str(error))
|
||||
result["version"] = args.get("current_version")
|
||||
return result
|
||||
|
||||
@ -55,7 +57,7 @@ def _has_new_version(*, latest_version: str, current_version: str) -> bool:
|
||||
# Compare versions
|
||||
return latest > current
|
||||
except version.InvalidVersion:
|
||||
logging.warning("Invalid version format: latest=%s, current=%s", latest_version, current_version)
|
||||
logger.warning("Invalid version format: latest=%s, current=%s", latest_version, current_version)
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import TenantAccountRole
|
||||
from models.account import Account, TenantAccountRole
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
|
||||
|
||||
@ -15,10 +15,12 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
assert isinstance(current_user, Account)
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_role):
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
assert tenant_id is not None
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
@ -64,10 +66,12 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str, config_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_role):
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
assert tenant_id is not None
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
|
||||
@ -54,7 +54,7 @@ class MemberInviteEmailApi(Resource):
|
||||
@cloud_edition_billing_resource_check("members")
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("emails", type=str, required=True, location="json", action="append")
|
||||
parser.add_argument("emails", type=list, required=True, location="json")
|
||||
parser.add_argument("role", type=str, required=True, default="admin", location="json")
|
||||
parser.add_argument("language", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -10,6 +10,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import StrLen, uuid_value
|
||||
from libs.login import login_required
|
||||
from services.billing_service import BillingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
@ -45,12 +46,109 @@ class ModelProviderCredentialApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
# if credential_id is not provided, return current used credential
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
credentials = model_provider_service.get_provider_credentials(tenant_id=tenant_id, provider=provider)
|
||||
credentials = model_provider_service.get_provider_credential(
|
||||
tenant_id=tenant_id, provider=provider, credential_id=args.get("credential_id")
|
||||
)
|
||||
|
||||
return {"credentials": credentials}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
try:
|
||||
model_provider_service.create_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
credentials=args["credentials"],
|
||||
credential_name=args["name"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
try:
|
||||
model_provider_service.update_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
credentials=args["credentials"],
|
||||
credential_id=args["credential_id"],
|
||||
credential_name=args["name"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"]
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class ModelProviderCredentialSwitchApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
service = ModelProviderService()
|
||||
service.switch_active_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
credential_id=args["credential_id"],
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class ModelProviderValidateApi(Resource):
|
||||
@setup_required
|
||||
@ -69,7 +167,7 @@ class ModelProviderValidateApi(Resource):
|
||||
error = ""
|
||||
|
||||
try:
|
||||
model_provider_service.provider_credentials_validate(
|
||||
model_provider_service.validate_provider_credentials(
|
||||
tenant_id=tenant_id, provider=provider, credentials=args["credentials"]
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
@ -84,42 +182,6 @@ class ModelProviderValidateApi(Resource):
|
||||
return response
|
||||
|
||||
|
||||
class ModelProviderApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
try:
|
||||
model_provider_service.save_provider_credentials(
|
||||
tenant_id=current_user.current_tenant_id, provider=provider, credentials=args["credentials"]
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_provider_credentials(tenant_id=current_user.current_tenant_id, provider=provider)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class ModelProviderIconApi(Resource):
|
||||
"""
|
||||
Get model provider icon
|
||||
@ -187,8 +249,10 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
|
||||
api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")
|
||||
|
||||
api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials")
|
||||
api.add_resource(
|
||||
ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers/<path:provider>/credentials/switch"
|
||||
)
|
||||
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
|
||||
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<path:provider>")
|
||||
|
||||
api.add_resource(
|
||||
PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type"
|
||||
|
||||
@ -9,10 +9,13 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import StrLen, uuid_value
|
||||
from libs.login import login_required
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DefaultModelApi(Resource):
|
||||
@setup_required
|
||||
@ -72,7 +75,7 @@ class DefaultModelApi(Resource):
|
||||
model=model_setting["model"],
|
||||
)
|
||||
except Exception as ex:
|
||||
logging.exception(
|
||||
logger.exception(
|
||||
"Failed to update default model, model type: %s, model: %s",
|
||||
model_setting["model_type"],
|
||||
model_setting.get("model"),
|
||||
@ -98,6 +101,7 @@ class ModelProviderModelApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
# To save the model's load balance configs
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
@ -113,22 +117,26 @@ class ModelProviderModelApi(Resource):
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("config_from", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.get("config_from", "") == "custom-model":
|
||||
if not args.get("credential_id"):
|
||||
raise ValueError("credential_id is required when configuring a custom-model")
|
||||
service = ModelProviderService()
|
||||
service.switch_active_custom_model_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
credential_id=args["credential_id"],
|
||||
)
|
||||
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
|
||||
if (
|
||||
"load_balancing" in args
|
||||
and args["load_balancing"]
|
||||
and "enabled" in args["load_balancing"]
|
||||
and args["load_balancing"]["enabled"]
|
||||
):
|
||||
if "configs" not in args["load_balancing"]:
|
||||
raise ValueError("invalid load balancing configs")
|
||||
|
||||
if "load_balancing" in args and args["load_balancing"] and "configs" in args["load_balancing"]:
|
||||
# save load balancing configs
|
||||
model_load_balancing_service.update_load_balancing_configs(
|
||||
tenant_id=tenant_id,
|
||||
@ -136,37 +144,17 @@ class ModelProviderModelApi(Resource):
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
configs=args["load_balancing"]["configs"],
|
||||
config_from=args.get("config_from", ""),
|
||||
)
|
||||
|
||||
# enable load balancing
|
||||
model_load_balancing_service.enable_model_load_balancing(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
)
|
||||
else:
|
||||
# disable load balancing
|
||||
model_load_balancing_service.disable_model_load_balancing(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
)
|
||||
|
||||
if args.get("config_from", "") != "predefined-model":
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
try:
|
||||
model_provider_service.save_model_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
credentials=args["credentials"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
logging.exception(
|
||||
"Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s",
|
||||
tenant_id,
|
||||
args.get("model"),
|
||||
args.get("model_type"),
|
||||
)
|
||||
raise ValueError(str(ex))
|
||||
if args.get("load_balancing", {}).get("enabled"):
|
||||
model_load_balancing_service.enable_model_load_balancing(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
)
|
||||
else:
|
||||
model_load_balancing_service.disable_model_load_balancing(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -192,7 +180,7 @@ class ModelProviderModelApi(Resource):
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_model_credentials(
|
||||
model_provider_service.remove_model(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
)
|
||||
|
||||
@ -216,11 +204,17 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="args",
|
||||
)
|
||||
parser.add_argument("config_from", type=str, required=False, nullable=True, location="args")
|
||||
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
credentials = model_provider_service.get_model_credentials(
|
||||
tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"]
|
||||
current_credential = model_provider_service.get_model_credential(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
credential_id=args.get("credential_id"),
|
||||
)
|
||||
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
@ -228,10 +222,173 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
)
|
||||
|
||||
return {
|
||||
"credentials": credentials,
|
||||
"load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs},
|
||||
}
|
||||
if args.get("config_from", "") == "predefined-model":
|
||||
available_credentials = model_provider_service.provider_manager.get_provider_available_credentials(
|
||||
tenant_id=tenant_id, provider_name=provider
|
||||
)
|
||||
else:
|
||||
model_type = ModelType.value_of(args["model_type"]).to_origin_model_type()
|
||||
available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
|
||||
tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args["model"]
|
||||
)
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"credentials": current_credential.get("credentials") if current_credential else {},
|
||||
"current_credential_id": current_credential.get("current_credential_id")
|
||||
if current_credential
|
||||
else None,
|
||||
"current_credential_name": current_credential.get("current_credential_name")
|
||||
if current_credential
|
||||
else None,
|
||||
"load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs},
|
||||
"available_credentials": available_credentials,
|
||||
}
|
||||
)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
try:
|
||||
model_provider_service.create_model_credential(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
credentials=args["credentials"],
|
||||
credential_name=args["name"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
logger.exception(
|
||||
"Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s",
|
||||
tenant_id,
|
||||
args.get("model"),
|
||||
args.get("model_type"),
|
||||
)
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
try:
|
||||
model_provider_service.update_model_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
credentials=args["credentials"],
|
||||
credential_id=args["credential_id"],
|
||||
credential_name=args["name"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_model_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
credential_id=args["credential_id"],
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class ModelProviderModelCredentialSwitchApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
service = ModelProviderService()
|
||||
service.add_model_credential_to_model_list(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
credential_id=args["credential_id"],
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class ModelProviderModelEnableApi(Resource):
|
||||
@ -314,7 +471,7 @@ class ModelProviderModelValidateApi(Resource):
|
||||
error = ""
|
||||
|
||||
try:
|
||||
model_provider_service.model_credentials_validate(
|
||||
model_provider_service.validate_model_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
@ -379,6 +536,10 @@ api.add_resource(
|
||||
api.add_resource(
|
||||
ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials"
|
||||
)
|
||||
api.add_resource(
|
||||
ModelProviderModelCredentialSwitchApi,
|
||||
"/workspaces/current/model-providers/<path:provider>/models/credentials/switch",
|
||||
)
|
||||
api.add_resource(
|
||||
ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate"
|
||||
)
|
||||
|
||||
@ -95,7 +95,6 @@ class ToolBuiltinProviderInfoApi(Resource):
|
||||
def get(self, provider):
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
|
||||
|
||||
@ -31,6 +31,9 @@ from services.feature_service import FeatureService
|
||||
from services.file_service import FileService
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
provider_fields = {
|
||||
"provider_name": fields.String,
|
||||
"provider_type": fields.String,
|
||||
@ -120,7 +123,7 @@ class TenantApi(Resource):
|
||||
@marshal_with(tenant_fields)
|
||||
def get(self):
|
||||
if request.path == "/info":
|
||||
logging.warning("Deprecated URL /info was used.")
|
||||
logger.warning("Deprecated URL /info was used.")
|
||||
|
||||
tenant = current_user.current_tenant
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ api = ExternalApi(
|
||||
doc="/docs", # Enable Swagger UI at /files/docs
|
||||
)
|
||||
|
||||
files_ns = Namespace("files", description="File operations")
|
||||
files_ns = Namespace("files", description="File operations", path="/")
|
||||
|
||||
from . import image_preview, tool_files, upload
|
||||
|
||||
|
||||
@ -1,10 +1,23 @@
|
||||
from flask import Blueprint
|
||||
from flask_restx import Namespace
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint("inner_api", __name__, url_prefix="/inner/api")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
api = ExternalApi(
|
||||
bp,
|
||||
version="1.0",
|
||||
title="Inner API",
|
||||
description="Internal APIs for enterprise features, billing, and plugin communication",
|
||||
doc="/docs", # Enable Swagger UI at /inner/api/docs
|
||||
)
|
||||
|
||||
# Create namespace
|
||||
inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/")
|
||||
|
||||
from . import mail
|
||||
from .plugin import plugin
|
||||
from .workspace import workspace
|
||||
|
||||
api.add_namespace(inner_api_ns)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import api
|
||||
from controllers.inner_api import inner_api_ns
|
||||
from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only
|
||||
from tasks.mail_inner_task import send_inner_email_task
|
||||
|
||||
@ -26,13 +26,45 @@ class BaseMail(Resource):
|
||||
return {"message": "success"}, 200
|
||||
|
||||
|
||||
@inner_api_ns.route("/enterprise/mail")
|
||||
class EnterpriseMail(BaseMail):
|
||||
method_decorators = [setup_required, enterprise_inner_api_only]
|
||||
|
||||
@inner_api_ns.doc("send_enterprise_mail")
|
||||
@inner_api_ns.doc(description="Send internal email for enterprise features")
|
||||
@inner_api_ns.expect(_mail_parser)
|
||||
@inner_api_ns.doc(
|
||||
responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"}
|
||||
)
|
||||
def post(self):
|
||||
"""Send internal email for enterprise features.
|
||||
|
||||
This endpoint allows sending internal emails for enterprise-specific
|
||||
notifications and communications.
|
||||
|
||||
Returns:
|
||||
dict: Success message with status code 200
|
||||
"""
|
||||
return super().post()
|
||||
|
||||
|
||||
@inner_api_ns.route("/billing/mail")
|
||||
class BillingMail(BaseMail):
|
||||
method_decorators = [setup_required, billing_inner_api_only]
|
||||
|
||||
@inner_api_ns.doc("send_billing_mail")
|
||||
@inner_api_ns.doc(description="Send internal email for billing notifications")
|
||||
@inner_api_ns.expect(_mail_parser)
|
||||
@inner_api_ns.doc(
|
||||
responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"}
|
||||
)
|
||||
def post(self):
|
||||
"""Send internal email for billing notifications.
|
||||
|
||||
api.add_resource(EnterpriseMail, "/enterprise/mail")
|
||||
api.add_resource(BillingMail, "/billing/mail")
|
||||
This endpoint allows sending internal emails for billing-related
|
||||
notifications and alerts.
|
||||
|
||||
Returns:
|
||||
dict: Success message with status code 200
|
||||
"""
|
||||
return super().post()
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import api
|
||||
from controllers.inner_api import inner_api_ns
|
||||
from controllers.inner_api.plugin.wraps import get_user_tenant, plugin_data
|
||||
from controllers.inner_api.wraps import plugin_inner_api_only
|
||||
from core.file.helpers import get_signed_file_url_for_plugin
|
||||
@ -35,11 +35,21 @@ from models.account import Account, Tenant
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/llm")
|
||||
class PluginInvokeLLMApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeLLM)
|
||||
@inner_api_ns.doc("plugin_invoke_llm")
|
||||
@inner_api_ns.doc(description="Invoke LLM models through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "LLM invocation successful (streaming response)",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLM):
|
||||
def generator():
|
||||
response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload)
|
||||
@ -48,11 +58,21 @@ class PluginInvokeLLMApi(Resource):
|
||||
return length_prefixed_response(0xF, generator())
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/llm/structured-output")
|
||||
class PluginInvokeLLMWithStructuredOutputApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput)
|
||||
@inner_api_ns.doc("plugin_invoke_llm_structured")
|
||||
@inner_api_ns.doc(description="Invoke LLM models with structured output through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "LLM structured output invocation successful (streaming response)",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLMWithStructuredOutput):
|
||||
def generator():
|
||||
response = PluginModelBackwardsInvocation.invoke_llm_with_structured_output(
|
||||
@ -63,11 +83,21 @@ class PluginInvokeLLMWithStructuredOutputApi(Resource):
|
||||
return length_prefixed_response(0xF, generator())
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/text-embedding")
|
||||
class PluginInvokeTextEmbeddingApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeTextEmbedding)
|
||||
@inner_api_ns.doc("plugin_invoke_text_embedding")
|
||||
@inner_api_ns.doc(description="Invoke text embedding models through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Text embedding successful",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTextEmbedding):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
@ -83,11 +113,17 @@ class PluginInvokeTextEmbeddingApi(Resource):
|
||||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/rerank")
|
||||
class PluginInvokeRerankApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeRerank)
|
||||
@inner_api_ns.doc("plugin_invoke_rerank")
|
||||
@inner_api_ns.doc(description="Invoke rerank models through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={200: "Rerank successful", 401: "Unauthorized - invalid API key", 404: "Service not available"}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeRerank):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
@ -103,11 +139,21 @@ class PluginInvokeRerankApi(Resource):
|
||||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/tts")
|
||||
class PluginInvokeTTSApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeTTS)
|
||||
@inner_api_ns.doc("plugin_invoke_tts")
|
||||
@inner_api_ns.doc(description="Invoke text-to-speech models through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "TTS invocation successful (streaming response)",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTTS):
|
||||
def generator():
|
||||
response = PluginModelBackwardsInvocation.invoke_tts(
|
||||
@ -120,11 +166,17 @@ class PluginInvokeTTSApi(Resource):
|
||||
return length_prefixed_response(0xF, generator())
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/speech2text")
|
||||
class PluginInvokeSpeech2TextApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeSpeech2Text)
|
||||
@inner_api_ns.doc("plugin_invoke_speech2text")
|
||||
@inner_api_ns.doc(description="Invoke speech-to-text models through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={200: "Speech2Text successful", 401: "Unauthorized - invalid API key", 404: "Service not available"}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSpeech2Text):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
@ -140,11 +192,17 @@ class PluginInvokeSpeech2TextApi(Resource):
|
||||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/moderation")
|
||||
class PluginInvokeModerationApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeModeration)
|
||||
@inner_api_ns.doc("plugin_invoke_moderation")
|
||||
@inner_api_ns.doc(description="Invoke moderation models through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={200: "Moderation successful", 401: "Unauthorized - invalid API key", 404: "Service not available"}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeModeration):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
@ -160,11 +218,21 @@ class PluginInvokeModerationApi(Resource):
|
||||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/tool")
|
||||
class PluginInvokeToolApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeTool)
|
||||
@inner_api_ns.doc("plugin_invoke_tool")
|
||||
@inner_api_ns.doc(description="Invoke tools through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Tool invocation successful (streaming response)",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTool):
|
||||
def generator():
|
||||
return PluginToolBackwardsInvocation.convert_to_event_stream(
|
||||
@ -182,11 +250,21 @@ class PluginInvokeToolApi(Resource):
|
||||
return length_prefixed_response(0xF, generator())
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/parameter-extractor")
|
||||
class PluginInvokeParameterExtractorNodeApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeParameterExtractorNode)
|
||||
@inner_api_ns.doc("plugin_invoke_parameter_extractor")
|
||||
@inner_api_ns.doc(description="Invoke parameter extractor node through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Parameter extraction successful",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeParameterExtractorNode):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
@ -205,11 +283,21 @@ class PluginInvokeParameterExtractorNodeApi(Resource):
|
||||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/question-classifier")
|
||||
class PluginInvokeQuestionClassifierNodeApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeQuestionClassifierNode)
|
||||
@inner_api_ns.doc("plugin_invoke_question_classifier")
|
||||
@inner_api_ns.doc(description="Invoke question classifier node through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Question classification successful",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeQuestionClassifierNode):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
@ -228,11 +316,21 @@ class PluginInvokeQuestionClassifierNodeApi(Resource):
|
||||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/app")
|
||||
class PluginInvokeAppApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeApp)
|
||||
@inner_api_ns.doc("plugin_invoke_app")
|
||||
@inner_api_ns.doc(description="Invoke application through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "App invocation successful (streaming response)",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeApp):
|
||||
response = PluginAppBackwardsInvocation.invoke_app(
|
||||
app_id=payload.app_id,
|
||||
@ -248,11 +346,21 @@ class PluginInvokeAppApi(Resource):
|
||||
return length_prefixed_response(0xF, PluginAppBackwardsInvocation.convert_to_event_stream(response))
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/encrypt")
|
||||
class PluginInvokeEncryptApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeEncrypt)
|
||||
@inner_api_ns.doc("plugin_invoke_encrypt")
|
||||
@inner_api_ns.doc(description="Encrypt or decrypt data through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Encryption/decryption successful",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeEncrypt):
|
||||
"""
|
||||
encrypt or decrypt data
|
||||
@ -265,11 +373,21 @@ class PluginInvokeEncryptApi(Resource):
|
||||
return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/summary")
|
||||
class PluginInvokeSummaryApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeSummary)
|
||||
@inner_api_ns.doc("plugin_invoke_summary")
|
||||
@inner_api_ns.doc(description="Invoke summary functionality through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Summary generation successful",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSummary):
|
||||
try:
|
||||
return BaseBackwardsInvocationResponse(
|
||||
@ -285,40 +403,43 @@ class PluginInvokeSummaryApi(Resource):
|
||||
return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
|
||||
|
||||
|
||||
@inner_api_ns.route("/upload/file/request")
|
||||
class PluginUploadFileRequestApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestRequestUploadFile)
|
||||
@inner_api_ns.doc("plugin_upload_file_request")
|
||||
@inner_api_ns.doc(description="Request signed URL for file upload through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Signed URL generated successfully",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile):
|
||||
# generate signed url
|
||||
url = get_signed_file_url_for_plugin(payload.filename, payload.mimetype, tenant_model.id, user_model.id)
|
||||
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
|
||||
|
||||
|
||||
@inner_api_ns.route("/fetch/app/info")
|
||||
class PluginFetchAppInfoApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestFetchAppInfo)
|
||||
@inner_api_ns.doc("plugin_fetch_app_info")
|
||||
@inner_api_ns.doc(description="Fetch application information through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "App information retrieved successfully",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestFetchAppInfo):
|
||||
return BaseBackwardsInvocationResponse(
|
||||
data=PluginAppBackwardsInvocation.fetch_app_info(payload.app_id, tenant_model.id)
|
||||
).model_dump()
|
||||
|
||||
|
||||
api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
|
||||
api.add_resource(PluginInvokeLLMWithStructuredOutputApi, "/invoke/llm/structured-output")
|
||||
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
|
||||
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
|
||||
api.add_resource(PluginInvokeTTSApi, "/invoke/tts")
|
||||
api.add_resource(PluginInvokeSpeech2TextApi, "/invoke/speech2text")
|
||||
api.add_resource(PluginInvokeModerationApi, "/invoke/moderation")
|
||||
api.add_resource(PluginInvokeToolApi, "/invoke/tool")
|
||||
api.add_resource(PluginInvokeParameterExtractorNodeApi, "/invoke/parameter-extractor")
|
||||
api.add_resource(PluginInvokeQuestionClassifierNodeApi, "/invoke/question-classifier")
|
||||
api.add_resource(PluginInvokeAppApi, "/invoke/app")
|
||||
api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")
|
||||
api.add_resource(PluginInvokeSummaryApi, "/invoke/summary")
|
||||
api.add_resource(PluginUploadFileRequestApi, "/upload/file/request")
|
||||
api.add_resource(PluginFetchAppInfoApi, "/fetch/app/info")
|
||||
|
||||
@ -3,7 +3,7 @@ import json
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import api
|
||||
from controllers.inner_api import inner_api_ns
|
||||
from controllers.inner_api.wraps import enterprise_inner_api_only
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
@ -11,9 +11,19 @@ from models.account import Account
|
||||
from services.account_service import TenantService
|
||||
|
||||
|
||||
@inner_api_ns.route("/enterprise/workspace")
|
||||
class EnterpriseWorkspace(Resource):
|
||||
@setup_required
|
||||
@enterprise_inner_api_only
|
||||
@inner_api_ns.doc("create_enterprise_workspace")
|
||||
@inner_api_ns.doc(description="Create a new enterprise workspace with owner assignment")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Workspace created successfully",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Owner account not found or service not available",
|
||||
}
|
||||
)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
@ -44,9 +54,19 @@ class EnterpriseWorkspace(Resource):
|
||||
}
|
||||
|
||||
|
||||
@inner_api_ns.route("/enterprise/workspace/ownerless")
|
||||
class EnterpriseWorkspaceNoOwnerEmail(Resource):
|
||||
@setup_required
|
||||
@enterprise_inner_api_only
|
||||
@inner_api_ns.doc("create_enterprise_workspace_ownerless")
|
||||
@inner_api_ns.doc(description="Create a new enterprise workspace without initial owner assignment")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Workspace created successfully",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
@ -71,7 +91,3 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
|
||||
"message": "enterprise workspace created.",
|
||||
"tenant": resp,
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(EnterpriseWorkspace, "/enterprise/workspace")
|
||||
api.add_resource(EnterpriseWorkspaceNoOwnerEmail, "/enterprise/workspace/ownerless")
|
||||
|
||||
@ -1,8 +1,12 @@
|
||||
from base64 import b64encode
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from hashlib import sha1
|
||||
from hmac import new as hmac_new
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
from flask import abort, request
|
||||
|
||||
from configs import dify_config
|
||||
@ -10,9 +14,9 @@ from extensions.ext_database import db
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
def billing_inner_api_only(view):
|
||||
def billing_inner_api_only(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.INNER_API:
|
||||
abort(404)
|
||||
|
||||
@ -26,9 +30,9 @@ def billing_inner_api_only(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def enterprise_inner_api_only(view):
|
||||
def enterprise_inner_api_only(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.INNER_API:
|
||||
abort(404)
|
||||
|
||||
@ -78,9 +82,9 @@ def enterprise_inner_api_user_auth(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def plugin_inner_api_only(view):
|
||||
def plugin_inner_api_only(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.PLUGIN_DAEMON_KEY:
|
||||
abort(404)
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ api = ExternalApi(
|
||||
doc="/docs", # Enable Swagger UI at /mcp/docs
|
||||
)
|
||||
|
||||
mcp_ns = Namespace("mcp", description="MCP operations")
|
||||
mcp_ns = Namespace("mcp", description="MCP operations", path="/")
|
||||
|
||||
from . import mcp
|
||||
|
||||
|
||||
@ -1,18 +1,27 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, reqparse
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console.app.mcp_server import AppMCPServerStatus
|
||||
from controllers.mcp import mcp_ns
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.mcp import types
|
||||
from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler
|
||||
from core.mcp.types import ClientNotification, ClientRequest
|
||||
from core.mcp.utils import create_mcp_error_response
|
||||
from core.mcp import types as mcp_types
|
||||
from core.mcp.server.streamable_http import handle_mcp_request
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.model import App, AppMCPServer, AppMode
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
|
||||
|
||||
class MCPRequestError(Exception):
|
||||
"""Custom exception for MCP request processing errors"""
|
||||
|
||||
def __init__(self, error_code: int, message: str):
|
||||
self.error_code = error_code
|
||||
self.message = message
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def int_or_str(value):
|
||||
@ -63,77 +72,173 @@ class MCPAppApi(Resource):
|
||||
Raises:
|
||||
ValidationError: Invalid request format or parameters
|
||||
"""
|
||||
# Parse and validate all arguments
|
||||
args = mcp_request_parser.parse_args()
|
||||
|
||||
request_id: Optional[Union[int, str]] = args.get("id")
|
||||
mcp_request = self._parse_mcp_request(args)
|
||||
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
|
||||
if not server:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found")
|
||||
)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get MCP server and app
|
||||
mcp_server, app = self._get_mcp_server_and_app(server_code, session)
|
||||
self._validate_server_status(mcp_server)
|
||||
|
||||
if server.status != AppMCPServerStatus.ACTIVE:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active")
|
||||
)
|
||||
# Get user input form
|
||||
user_input_form = self._get_user_input_form(app)
|
||||
|
||||
app = db.session.query(App).where(App.id == server.app_id).first()
|
||||
# Handle notification vs request differently
|
||||
return self._process_mcp_message(mcp_request, request_id, app, mcp_server, user_input_form, session)
|
||||
|
||||
def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]:
|
||||
"""Get and validate MCP server and app in one query session"""
|
||||
mcp_server = session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
|
||||
if not mcp_server:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found")
|
||||
|
||||
app = session.query(App).where(App.id == mcp_server.app_id).first()
|
||||
if not app:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found")
|
||||
)
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found")
|
||||
|
||||
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
workflow = app.workflow
|
||||
if workflow is None:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
|
||||
)
|
||||
return mcp_server, app
|
||||
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
def _validate_server_status(self, mcp_server: AppMCPServer) -> None:
|
||||
"""Validate MCP server status"""
|
||||
if mcp_server.status != AppMCPServerStatus.ACTIVE:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active")
|
||||
|
||||
def _process_mcp_message(
|
||||
self,
|
||||
mcp_request: mcp_types.ClientRequest | mcp_types.ClientNotification,
|
||||
request_id: Optional[Union[int, str]],
|
||||
app: App,
|
||||
mcp_server: AppMCPServer,
|
||||
user_input_form: list[VariableEntity],
|
||||
session: Session,
|
||||
) -> Response:
|
||||
"""Process MCP message (notification or request)"""
|
||||
if isinstance(mcp_request, mcp_types.ClientNotification):
|
||||
return self._handle_notification(mcp_request)
|
||||
else:
|
||||
app_model_config = app.app_model_config
|
||||
if app_model_config is None:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
|
||||
)
|
||||
return self._handle_request(mcp_request, request_id, app, mcp_server, user_input_form, session)
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
converted_user_input_form: list[VariableEntity] = []
|
||||
try:
|
||||
for item in user_input_form:
|
||||
variable_type = item.get("type", "") or list(item.keys())[0]
|
||||
variable = item[variable_type]
|
||||
converted_user_input_form.append(
|
||||
VariableEntity(
|
||||
type=variable_type,
|
||||
variable=variable.get("variable"),
|
||||
description=variable.get("description") or "",
|
||||
label=variable.get("label"),
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options") or [],
|
||||
)
|
||||
)
|
||||
except ValidationError as e:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
|
||||
)
|
||||
def _handle_notification(self, mcp_request: mcp_types.ClientNotification) -> Response:
|
||||
"""Handle MCP notification"""
|
||||
# For notifications, only support init notification
|
||||
if mcp_request.root.method != "notifications/initialized":
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Invalid notification method")
|
||||
# Return HTTP 202 Accepted for notifications (no response body)
|
||||
return Response("", status=202, content_type="application/json")
|
||||
|
||||
def _handle_request(
|
||||
self,
|
||||
mcp_request: mcp_types.ClientRequest,
|
||||
request_id: Optional[Union[int, str]],
|
||||
app: App,
|
||||
mcp_server: AppMCPServer,
|
||||
user_input_form: list[VariableEntity],
|
||||
session: Session,
|
||||
) -> Response:
|
||||
"""Handle MCP request"""
|
||||
if request_id is None:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Request ID is required")
|
||||
|
||||
result = self._handle_mcp_request(app, mcp_server, mcp_request, user_input_form, session, request_id)
|
||||
if result is None:
|
||||
# This shouldn't happen for requests, but handle gracefully
|
||||
raise MCPRequestError(mcp_types.INTERNAL_ERROR, "No response generated for request")
|
||||
|
||||
return helper.compact_generate_response(result.model_dump(by_alias=True, mode="json", exclude_none=True))
|
||||
|
||||
def _get_user_input_form(self, app: App) -> list[VariableEntity]:
|
||||
"""Get and convert user input form"""
|
||||
# Get raw user input form based on app mode
|
||||
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
if not app.workflow:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable")
|
||||
raw_user_input_form = app.workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
if not app.app_model_config:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable")
|
||||
features_dict = app.app_model_config.to_dict()
|
||||
raw_user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
# Convert to VariableEntity objects
|
||||
try:
|
||||
request: ClientRequest | ClientNotification = ClientRequest.model_validate(args)
|
||||
return self._convert_user_input_form(raw_user_input_form)
|
||||
except ValidationError as e:
|
||||
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
|
||||
|
||||
def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]:
|
||||
"""Convert raw user input form to VariableEntity objects"""
|
||||
return [self._create_variable_entity(item) for item in raw_form]
|
||||
|
||||
def _create_variable_entity(self, item: dict) -> VariableEntity:
|
||||
"""Create a single VariableEntity from raw form item"""
|
||||
variable_type = item.get("type", "") or list(item.keys())[0]
|
||||
variable = item[variable_type]
|
||||
|
||||
return VariableEntity(
|
||||
type=variable_type,
|
||||
variable=variable.get("variable"),
|
||||
description=variable.get("description") or "",
|
||||
label=variable.get("label"),
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options") or [],
|
||||
)
|
||||
|
||||
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
|
||||
"""Parse and validate MCP request"""
|
||||
try:
|
||||
return mcp_types.ClientRequest.model_validate(args)
|
||||
except ValidationError:
|
||||
try:
|
||||
notification = ClientNotification.model_validate(args)
|
||||
request = notification
|
||||
return mcp_types.ClientNotification.model_validate(args)
|
||||
except ValidationError as e:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
|
||||
)
|
||||
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
|
||||
|
||||
mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
|
||||
response = mcp_server_handler.handle()
|
||||
return helper.compact_generate_response(response)
|
||||
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str, session: Session) -> EndUser | None:
|
||||
"""Get end user from existing session - optimized query"""
|
||||
return (
|
||||
session.query(EndUser)
|
||||
.where(EndUser.tenant_id == tenant_id)
|
||||
.where(EndUser.session_id == mcp_server_id)
|
||||
.where(EndUser.type == "mcp")
|
||||
.first()
|
||||
)
|
||||
|
||||
def _create_end_user(
|
||||
self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session
|
||||
) -> EndUser:
|
||||
"""Create end user in existing session"""
|
||||
end_user = EndUser(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
type="mcp",
|
||||
name=client_name,
|
||||
session_id=mcp_server_id,
|
||||
)
|
||||
session.add(end_user)
|
||||
session.flush() # Use flush instead of commit to keep transaction open
|
||||
session.refresh(end_user)
|
||||
return end_user
|
||||
|
||||
def _handle_mcp_request(
|
||||
self,
|
||||
app: App,
|
||||
mcp_server: AppMCPServer,
|
||||
mcp_request: mcp_types.ClientRequest,
|
||||
user_input_form: list[VariableEntity],
|
||||
session: Session,
|
||||
request_id: Union[int, str],
|
||||
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None:
|
||||
"""Handle MCP request and return response"""
|
||||
end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id, session)
|
||||
|
||||
if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
|
||||
client_info = mcp_request.root.params.clientInfo
|
||||
client_name = f"{client_info.name}@{client_info.version}"
|
||||
# Commit the session before creating end user to avoid transaction conflicts
|
||||
session.commit()
|
||||
with Session(db.engine, expire_on_commit=False) as create_session, create_session.begin():
|
||||
end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session)
|
||||
|
||||
return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id)
|
||||
|
||||
@ -13,7 +13,7 @@ api = ExternalApi(
|
||||
doc="/docs", # Enable Swagger UI at /v1/docs
|
||||
)
|
||||
|
||||
service_api_ns = Namespace("service_api", description="Service operations")
|
||||
service_api_ns = Namespace("service_api", description="Service operations", path="/")
|
||||
|
||||
from . import index
|
||||
from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow
|
||||
|
||||
@ -10,6 +10,7 @@ from controllers.service_api.wraps import validate_app_token
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import annotation_fields, build_annotation_model
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.model import App
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
@ -163,6 +164,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@service_api_ns.marshal_with(build_annotation_model(service_api_ns))
|
||||
def put(self, app_model: App, annotation_id):
|
||||
"""Update an existing annotation."""
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
@ -185,6 +187,8 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@validate_app_token
|
||||
def delete(self, app_model: App, annotation_id):
|
||||
"""Delete an annotation."""
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
@ -29,6 +29,8 @@ from services.errors.audio import (
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@service_api_ns.route("/audio-to-text")
|
||||
class AudioApi(Resource):
|
||||
@ -57,7 +59,7 @@ class AudioApi(Resource):
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
@ -78,7 +80,7 @@ class AudioApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -121,7 +123,7 @@ class TextApi(Resource):
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
@ -142,5 +144,5 @@ class TextApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
@ -33,6 +33,9 @@ from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define parser for completion API
|
||||
completion_parser = reqparse.RequestParser()
|
||||
completion_parser.add_argument(
|
||||
@ -118,7 +121,7 @@ class CompletionApi(Resource):
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -131,7 +134,7 @@ class CompletionApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -209,7 +212,7 @@ class ChatApi(Resource):
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -224,7 +227,7 @@ class ChatApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
|
||||
@ -22,6 +22,9 @@ from services.errors.message import (
|
||||
)
|
||||
from services.message_service import MessageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define parsers for message APIs
|
||||
message_list_parser = reqparse.RequestParser()
|
||||
message_list_parser.add_argument(
|
||||
@ -216,7 +219,7 @@ class MessageSuggestedApi(Resource):
|
||||
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||
raise BadRequest("Suggested Questions Is Disabled.")
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {"result": "success", "data": questions}
|
||||
|
||||
@ -174,7 +174,7 @@ class WorkflowRunApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -239,7 +239,7 @@ class WorkflowRunByIdApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ from core.provider_manager import ProviderManager
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import build_dataset_tag_fields
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
@ -213,7 +214,10 @@ class DatasetListApi(DatasetApiResource):
|
||||
)
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
assert cid is not None
|
||||
configurations = provider_manager.get_configurations(tenant_id=cid)
|
||||
|
||||
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
|
||||
|
||||
@ -266,6 +270,7 @@ class DatasetListApi(DatasetApiResource):
|
||||
)
|
||||
|
||||
try:
|
||||
assert isinstance(current_user, Account)
|
||||
dataset = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=args["name"],
|
||||
@ -313,13 +318,12 @@ class DatasetApi(DatasetApiResource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = marshal(dataset, dataset_detail_fields)
|
||||
if data.get("permission") == "partial_members":
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
data.update({"partial_member_list": part_users_list})
|
||||
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
assert cid is not None
|
||||
configurations = provider_manager.get_configurations(tenant_id=cid)
|
||||
|
||||
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
|
||||
|
||||
@ -391,6 +395,7 @@ class DatasetApi(DatasetApiResource):
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
result_data = marshal(dataset, dataset_detail_fields)
|
||||
assert isinstance(current_user, Account)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
if data.get("partial_member_list") and data.get("permission") == "partial_members":
|
||||
@ -532,7 +537,10 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
def get(self, _, dataset_id):
|
||||
"""Get all knowledge type tags."""
|
||||
tags = TagService.get_tags("knowledge", current_user.current_tenant_id)
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
assert cid is not None
|
||||
tags = TagService.get_tags("knowledge", cid)
|
||||
|
||||
return tags, 200
|
||||
|
||||
@ -550,6 +558,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
"""Add a knowledge type tag."""
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
@ -573,6 +582,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
@validate_dataset_token
|
||||
def patch(self, _, dataset_id):
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
@ -599,6 +609,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
@validate_dataset_token
|
||||
def delete(self, _, dataset_id):
|
||||
"""Delete a knowledge type tag."""
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
args = tag_delete_parser.parse_args()
|
||||
@ -622,6 +633,7 @@ class DatasetTagBindingApi(DatasetApiResource):
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
@ -647,6 +659,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
@ -672,6 +685,8 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
|
||||
def get(self, _, *args, **kwargs):
|
||||
"""Get all knowledge type tags."""
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
|
||||
tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
|
||||
response = {"data": tags_list, "total": len(tags)}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in # type: ignore
|
||||
from flask_login import user_logged_in
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select, update
|
||||
|
||||
@ -1,19 +1,20 @@
|
||||
from flask import Blueprint
|
||||
from flask_restx import Namespace
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
from .files import FileApi
|
||||
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
|
||||
|
||||
bp = Blueprint("web", __name__, url_prefix="/api")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
# Files
|
||||
api.add_resource(FileApi, "/files/upload")
|
||||
api = ExternalApi(
|
||||
bp,
|
||||
version="1.0",
|
||||
title="Web API",
|
||||
description="Public APIs for web applications including file uploads, chat interactions, and app management",
|
||||
doc="/docs", # Enable Swagger UI at /api/docs
|
||||
)
|
||||
|
||||
# Remote files
|
||||
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
||||
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
|
||||
# Create namespace
|
||||
web_ns = Namespace("web", description="Web application API operations", path="/")
|
||||
|
||||
from . import (
|
||||
app,
|
||||
@ -21,11 +22,15 @@ from . import (
|
||||
completion,
|
||||
conversation,
|
||||
feature,
|
||||
files,
|
||||
forgot_password,
|
||||
login,
|
||||
message,
|
||||
passport,
|
||||
remote_files,
|
||||
saved_message,
|
||||
site,
|
||||
workflow,
|
||||
)
|
||||
|
||||
api.add_namespace(web_ns)
|
||||
|
||||
@ -5,7 +5,7 @@ from flask_restx import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.common import fields
|
||||
from controllers.web import api
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import AppUnavailableError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
@ -16,10 +16,25 @@ from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@web_ns.route("/parameters")
|
||||
class AppParameterApi(WebApiResource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
@web_ns.doc("Get App Parameters")
|
||||
@web_ns.doc(description="Retrieve the parameters for a specific app.")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "App Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(fields.parameters_fields)
|
||||
def get(self, app_model: App, end_user):
|
||||
"""Retrieve app parameters."""
|
||||
@ -42,13 +57,42 @@ class AppParameterApi(WebApiResource):
|
||||
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
|
||||
|
||||
@web_ns.route("/meta")
|
||||
class AppMeta(WebApiResource):
|
||||
@web_ns.doc("Get App Meta")
|
||||
@web_ns.doc(description="Retrieve the metadata for a specific app.")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "App Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self, app_model: App, end_user):
|
||||
"""Get app meta"""
|
||||
return AppService().get_app_meta(app_model)
|
||||
|
||||
|
||||
@web_ns.route("/webapp/access-mode")
|
||||
class AppAccessMode(Resource):
|
||||
@web_ns.doc("Get App Access Mode")
|
||||
@web_ns.doc(description="Retrieve the access mode for a web application (public or restricted).")
|
||||
@web_ns.doc(
|
||||
params={
|
||||
"appId": {"description": "Application ID", "type": "string", "required": False},
|
||||
"appCode": {"description": "Application code", "type": "string", "required": False},
|
||||
}
|
||||
)
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("appId", type=str, required=False, location="args")
|
||||
@ -72,7 +116,19 @@ class AppAccessMode(Resource):
|
||||
return {"accessMode": res.access_mode}
|
||||
|
||||
|
||||
@web_ns.route("/webapp/permission")
|
||||
class AppWebAuthPermission(Resource):
|
||||
@web_ns.doc("Check App Permission")
|
||||
@web_ns.doc(description="Check if user has permission to access a web application.")
|
||||
@web_ns.doc(params={"appId": {"description": "Application ID", "type": "string", "required": True}})
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self):
|
||||
user_id = "visitor"
|
||||
try:
|
||||
@ -92,7 +148,7 @@ class AppWebAuthPermission(Resource):
|
||||
except Unauthorized:
|
||||
raise
|
||||
except Exception:
|
||||
logging.exception("Unexpected error during auth verification")
|
||||
logger.exception("Unexpected error during auth verification")
|
||||
raise
|
||||
|
||||
features = FeatureService.get_system_features()
|
||||
@ -110,10 +166,3 @@ class AppWebAuthPermission(Resource):
|
||||
if WebAppAuthService.is_app_require_permission_check(app_id=app_id):
|
||||
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code)
|
||||
return {"result": res}
|
||||
|
||||
|
||||
api.add_resource(AppParameterApi, "/parameters")
|
||||
api.add_resource(AppMeta, "/meta")
|
||||
# webapp auth apis
|
||||
api.add_resource(AppAccessMode, "/webapp/access-mode")
|
||||
api.add_resource(AppWebAuthPermission, "/webapp/permission")
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import fields, marshal_with, reqparse
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
@ -28,9 +29,30 @@ from services.errors.audio import (
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AudioApi(WebApiResource):
|
||||
audio_to_text_response_fields = {
|
||||
"text": fields.String,
|
||||
}
|
||||
|
||||
@marshal_with(audio_to_text_response_fields)
|
||||
@api.doc("Audio to Text")
|
||||
@api.doc(description="Convert audio file to text using speech-to-text service.")
|
||||
@api.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
413: "Audio file too large",
|
||||
415: "Unsupported audio type",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model: App, end_user):
|
||||
"""Convert audio to text"""
|
||||
file = request.files["file"]
|
||||
|
||||
try:
|
||||
@ -38,7 +60,7 @@ class AudioApi(WebApiResource):
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
@ -59,14 +81,30 @@ class AudioApi(WebApiResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("Failed to handle post request to AudioApi")
|
||||
logger.exception("Failed to handle post request to AudioApi")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TextApi(WebApiResource):
|
||||
def post(self, app_model: App, end_user):
|
||||
from flask_restx import reqparse
|
||||
text_to_audio_response_fields = {
|
||||
"audio_url": fields.String,
|
||||
"duration": fields.Float,
|
||||
}
|
||||
|
||||
@marshal_with(text_to_audio_response_fields)
|
||||
@api.doc("Text to Audio")
|
||||
@api.doc(description="Convert text to audio using text-to-speech service.")
|
||||
@api.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model: App, end_user):
|
||||
"""Convert text to audio"""
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=str, required=False, location="json")
|
||||
@ -84,7 +122,7 @@ class TextApi(WebApiResource):
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
@ -105,7 +143,7 @@ class TextApi(WebApiResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("Failed to handle post request to TextApi")
|
||||
logger.exception("Failed to handle post request to TextApi")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
|
||||
@ -31,9 +31,37 @@ from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# define completion api for user
|
||||
class CompletionApi(WebApiResource):
|
||||
@api.doc("Create Completion Message")
|
||||
@api.doc(description="Create a completion message for text generation applications.")
|
||||
@api.doc(
|
||||
params={
|
||||
"inputs": {"description": "Input variables for the completion", "type": "object", "required": True},
|
||||
"query": {"description": "Query text for completion", "type": "string", "required": False},
|
||||
"files": {"description": "Files to be processed", "type": "array", "required": False},
|
||||
"response_mode": {
|
||||
"description": "Response mode: blocking or streaming",
|
||||
"type": "string",
|
||||
"enum": ["blocking", "streaming"],
|
||||
"required": False,
|
||||
},
|
||||
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
|
||||
}
|
||||
)
|
||||
@api.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "App Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model, end_user):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
@ -61,7 +89,7 @@ class CompletionApi(WebApiResource):
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -74,11 +102,24 @@ class CompletionApi(WebApiResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class CompletionStopApi(WebApiResource):
|
||||
@api.doc("Stop Completion Message")
|
||||
@api.doc(description="Stop a running completion message task.")
|
||||
@api.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}})
|
||||
@api.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "Task Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model, end_user, task_id):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
@ -89,6 +130,34 @@ class CompletionStopApi(WebApiResource):
|
||||
|
||||
|
||||
class ChatApi(WebApiResource):
|
||||
@api.doc("Create Chat Message")
|
||||
@api.doc(description="Create a chat message for conversational applications.")
|
||||
@api.doc(
|
||||
params={
|
||||
"inputs": {"description": "Input variables for the chat", "type": "object", "required": True},
|
||||
"query": {"description": "User query/message", "type": "string", "required": True},
|
||||
"files": {"description": "Files to be processed", "type": "array", "required": False},
|
||||
"response_mode": {
|
||||
"description": "Response mode: blocking or streaming",
|
||||
"type": "string",
|
||||
"enum": ["blocking", "streaming"],
|
||||
"required": False,
|
||||
},
|
||||
"conversation_id": {"description": "Conversation UUID", "type": "string", "required": False},
|
||||
"parent_message_id": {"description": "Parent message UUID", "type": "string", "required": False},
|
||||
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
|
||||
}
|
||||
)
|
||||
@api.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "App Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model, end_user):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -119,7 +188,7 @@ class ChatApi(WebApiResource):
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -134,11 +203,24 @@ class ChatApi(WebApiResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class ChatStopApi(WebApiResource):
|
||||
@api.doc("Stop Chat Message")
|
||||
@api.doc(description="Stop a running chat message task.")
|
||||
@api.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}})
|
||||
@api.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "Task Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model, end_user, task_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from flask_restx import marshal_with, reqparse
|
||||
from flask_restx import fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
@ -58,6 +58,11 @@ class ConversationListApi(WebApiResource):
|
||||
|
||||
|
||||
class ConversationApi(WebApiResource):
|
||||
delete_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@marshal_with(delete_response_fields)
|
||||
def delete(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -94,6 +99,11 @@ class ConversationRenameApi(WebApiResource):
|
||||
|
||||
|
||||
class ConversationPinApi(WebApiResource):
|
||||
pin_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@marshal_with(pin_response_fields)
|
||||
def patch(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -110,6 +120,11 @@ class ConversationPinApi(WebApiResource):
|
||||
|
||||
|
||||
class ConversationUnPinApi(WebApiResource):
|
||||
unpin_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@marshal_with(unpin_response_fields)
|
||||
def patch(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
|
||||
@ -1,12 +1,21 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web import web_ns
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
@web_ns.route("/system-features")
|
||||
class SystemFeatureApi(Resource):
|
||||
@web_ns.doc("get_system_features")
|
||||
@web_ns.doc(description="Get system feature flags and configuration")
|
||||
@web_ns.doc(responses={200: "System features retrieved successfully", 500: "Internal server error"})
|
||||
def get(self):
|
||||
"""Get system feature flags and configuration.
|
||||
|
||||
Returns the current system feature flags and configuration
|
||||
that control various functionalities across the platform.
|
||||
|
||||
Returns:
|
||||
dict: System feature configuration object
|
||||
"""
|
||||
return FeatureService.get_system_features().model_dump()
|
||||
|
||||
|
||||
api.add_resource(SystemFeatureApi, "/system-features")
|
||||
|
||||
@ -9,14 +9,50 @@ from controllers.common.errors import (
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from fields.file_fields import file_fields
|
||||
from fields.file_fields import build_file_model
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@web_ns.route("/files/upload")
|
||||
class FileApi(WebApiResource):
|
||||
@marshal_with(file_fields)
|
||||
@web_ns.doc("upload_file")
|
||||
@web_ns.doc(description="Upload a file for use in web applications")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
201: "File uploaded successfully",
|
||||
400: "Bad request - invalid file or parameters",
|
||||
413: "File too large",
|
||||
415: "Unsupported file type",
|
||||
}
|
||||
)
|
||||
@marshal_with(build_file_model(web_ns))
|
||||
def post(self, app_model, end_user):
|
||||
"""Upload a file for use in web applications.
|
||||
|
||||
Accepts file uploads for use within web applications, supporting
|
||||
multiple file types with automatic validation and storage.
|
||||
|
||||
Args:
|
||||
app_model: The associated application model
|
||||
end_user: The end user uploading the file
|
||||
|
||||
Form Parameters:
|
||||
file: The file to upload (required)
|
||||
source: Optional source type (datasets or None)
|
||||
|
||||
Returns:
|
||||
dict: File information including ID, URL, and metadata
|
||||
int: HTTP status code 201 for success
|
||||
|
||||
Raises:
|
||||
NoFileUploadedError: No file provided in request
|
||||
TooManyFilesError: Multiple files provided (only one allowed)
|
||||
FilenameNotExistsError: File has no filename
|
||||
FileTooLargeError: File exceeds size limit
|
||||
UnsupportedFileTypeError: File type not supported
|
||||
"""
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
|
||||
@ -7,15 +7,16 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
EmailCodeError,
|
||||
EmailPasswordResetLimitError,
|
||||
InvalidEmailError,
|
||||
InvalidTokenError,
|
||||
PasswordMismatchError,
|
||||
)
|
||||
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
|
||||
from controllers.console.error import EmailSendIpLimitError
|
||||
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
|
||||
from controllers.web import api
|
||||
from controllers.web import web_ns
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
@ -23,10 +24,21 @@ from models.account import Account
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
@web_ns.route("/forgot-password")
|
||||
class ForgotPasswordSendEmailApi(Resource):
|
||||
@only_edition_enterprise
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@web_ns.doc("send_forgot_password_email")
|
||||
@web_ns.doc(description="Send password reset email")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Password reset email sent successfully",
|
||||
400: "Bad request - invalid email format",
|
||||
404: "Account not found",
|
||||
429: "Too many requests - rate limit exceeded",
|
||||
}
|
||||
)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
@ -46,17 +58,23 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
||||
token = None
|
||||
if account is None:
|
||||
raise AccountNotFound()
|
||||
raise AuthenticationFailedError()
|
||||
else:
|
||||
token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)
|
||||
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@web_ns.route("/forgot-password/validity")
|
||||
class ForgotPasswordCheckApi(Resource):
|
||||
@only_edition_enterprise
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@web_ns.doc("check_forgot_password_token")
|
||||
@web_ns.doc(description="Verify password reset token validity")
|
||||
@web_ns.doc(
|
||||
responses={200: "Token is valid", 400: "Bad request - invalid token format", 401: "Invalid or expired token"}
|
||||
)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=str, required=True, location="json")
|
||||
@ -93,10 +111,21 @@ class ForgotPasswordCheckApi(Resource):
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
@web_ns.route("/forgot-password/resets")
|
||||
class ForgotPasswordResetApi(Resource):
|
||||
@only_edition_enterprise
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@web_ns.doc("reset_password")
|
||||
@web_ns.doc(description="Reset user password with verification token")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Password reset successfully",
|
||||
400: "Bad request - invalid parameters or password mismatch",
|
||||
401: "Invalid or expired token",
|
||||
404: "Account not found",
|
||||
}
|
||||
)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
@ -131,7 +160,7 @@ class ForgotPasswordResetApi(Resource):
|
||||
if account:
|
||||
self._update_existing_account(account, password_hashed, salt, session)
|
||||
else:
|
||||
raise AccountNotFound()
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -140,8 +169,3 @@ class ForgotPasswordResetApi(Resource):
|
||||
account.password = base64.b64encode(password_hashed).decode()
|
||||
account.password_salt = base64.b64encode(salt).decode()
|
||||
session.commit()
|
||||
|
||||
|
||||
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
|
||||
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
|
||||
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")
|
||||
|
||||
@ -1,22 +1,38 @@
|
||||
from flask_restx import Resource, reqparse
|
||||
from jwt import InvalidTokenError # type: ignore
|
||||
from jwt import InvalidTokenError
|
||||
|
||||
import services
|
||||
from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError
|
||||
from controllers.console.error import AccountBannedError, AccountNotFound
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
EmailCodeError,
|
||||
InvalidEmailError,
|
||||
)
|
||||
from controllers.console.error import AccountBannedError
|
||||
from controllers.console.wraps import only_edition_enterprise, setup_required
|
||||
from controllers.web import api
|
||||
from controllers.web import web_ns
|
||||
from libs.helper import email
|
||||
from libs.password import valid_password
|
||||
from services.account_service import AccountService
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
|
||||
@web_ns.route("/login")
|
||||
class LoginApi(Resource):
|
||||
"""Resource for web app email/password login."""
|
||||
|
||||
@setup_required
|
||||
@only_edition_enterprise
|
||||
@web_ns.doc("web_app_login")
|
||||
@web_ns.doc(description="Authenticate user for web application access")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Authentication successful",
|
||||
400: "Bad request - invalid email or password format",
|
||||
401: "Authentication failed - email or password mismatch",
|
||||
403: "Account banned or login disabled",
|
||||
404: "Account not found",
|
||||
}
|
||||
)
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
parser = reqparse.RequestParser()
|
||||
@ -29,9 +45,9 @@ class LoginApi(Resource):
|
||||
except services.errors.account.AccountLoginError:
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError:
|
||||
raise EmailOrPasswordMismatchError()
|
||||
raise AuthenticationFailedError()
|
||||
except services.errors.account.AccountNotFoundError:
|
||||
raise AccountNotFound()
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
token = WebAppAuthService.login(account=account)
|
||||
return {"result": "success", "data": {"access_token": token}}
|
||||
@ -47,9 +63,19 @@ class LoginApi(Resource):
|
||||
# return {"result": "success"}
|
||||
|
||||
|
||||
@web_ns.route("/email-code-login")
|
||||
class EmailCodeLoginSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@only_edition_enterprise
|
||||
@web_ns.doc("send_email_code_login")
|
||||
@web_ns.doc(description="Send email verification code for login")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Email code sent successfully",
|
||||
400: "Bad request - invalid email format",
|
||||
404: "Account not found",
|
||||
}
|
||||
)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
@ -63,16 +89,27 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
|
||||
account = WebAppAuthService.get_user_through_email(args["email"])
|
||||
if account is None:
|
||||
raise AccountNotFound()
|
||||
raise AuthenticationFailedError()
|
||||
else:
|
||||
token = WebAppAuthService.send_email_code_login_email(account=account, language=language)
|
||||
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@web_ns.route("/email-code-login/validity")
|
||||
class EmailCodeLoginApi(Resource):
|
||||
@setup_required
|
||||
@only_edition_enterprise
|
||||
@web_ns.doc("verify_email_code_login")
|
||||
@web_ns.doc(description="Verify email code and complete login")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Email code verified and login successful",
|
||||
400: "Bad request - invalid code or token",
|
||||
401: "Invalid token or expired code",
|
||||
404: "Account not found",
|
||||
}
|
||||
)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=str, required=True, location="json")
|
||||
@ -95,14 +132,8 @@ class EmailCodeLoginApi(Resource):
|
||||
WebAppAuthService.revoke_email_code_login_token(args["token"])
|
||||
account = WebAppAuthService.get_user_through_email(user_email)
|
||||
if not account:
|
||||
raise AccountNotFound()
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
token = WebAppAuthService.login(account=account)
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
return {"result": "success", "data": {"access_token": token}}
|
||||
|
||||
|
||||
api.add_resource(LoginApi, "/login")
|
||||
# api.add_resource(LogoutApi, "/logout")
|
||||
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
|
||||
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
|
||||
|
||||
@ -35,6 +35,8 @@ from services.errors.message import (
|
||||
)
|
||||
from services.message_service import MessageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageListApi(WebApiResource):
|
||||
message_fields = {
|
||||
@ -83,6 +85,11 @@ class MessageListApi(WebApiResource):
|
||||
|
||||
|
||||
class MessageFeedbackApi(WebApiResource):
|
||||
feedback_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@marshal_with(feedback_response_fields)
|
||||
def post(self, app_model, end_user, message_id):
|
||||
message_id = str(message_id)
|
||||
|
||||
@ -145,11 +152,16 @@ class MessageMoreLikeThisApi(WebApiResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class MessageSuggestedQuestionApi(WebApiResource):
|
||||
suggested_questions_response_fields = {
|
||||
"data": fields.List(fields.String),
|
||||
}
|
||||
|
||||
@marshal_with(suggested_questions_response_fields)
|
||||
def get(self, app_model, end_user, message_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -176,7 +188,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {"data": questions}
|
||||
|
||||
@ -7,7 +7,7 @@ from sqlalchemy import func, select
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.web import api
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import WebAppAuthRequiredError
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
@ -17,9 +17,19 @@ from services.feature_service import FeatureService
|
||||
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
|
||||
|
||||
|
||||
@web_ns.route("/passport")
|
||||
class PassportResource(Resource):
|
||||
"""Base resource for passport."""
|
||||
|
||||
@web_ns.doc("get_passport")
|
||||
@web_ns.doc(description="Get authentication passport for web application access")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Passport retrieved successfully",
|
||||
401: "Unauthorized - missing app code or invalid authentication",
|
||||
404: "Application or user not found",
|
||||
}
|
||||
)
|
||||
def get(self):
|
||||
system_features = FeatureService.get_system_features()
|
||||
app_code = request.headers.get("X-App-Code")
|
||||
@ -94,9 +104,6 @@ class PassportResource(Resource):
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(PassportResource, "/passport")
|
||||
|
||||
|
||||
def decode_enterprise_webapp_user_id(jwt_token: str | None):
|
||||
"""
|
||||
Decode the enterprise user session from the Authorization header.
|
||||
|
||||
@ -10,16 +10,44 @@ from controllers.common.errors import (
|
||||
RemoteFileUploadError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import ssrf_proxy
|
||||
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
||||
from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@web_ns.route("/remote-files/<path:url>")
|
||||
class RemoteFileInfoApi(WebApiResource):
|
||||
@marshal_with(remote_file_info_fields)
|
||||
@web_ns.doc("get_remote_file_info")
|
||||
@web_ns.doc(description="Get information about a remote file")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Remote file information retrieved successfully",
|
||||
400: "Bad request - invalid URL",
|
||||
404: "Remote file not found",
|
||||
500: "Failed to fetch remote file",
|
||||
}
|
||||
)
|
||||
@marshal_with(build_remote_file_info_model(web_ns))
|
||||
def get(self, app_model, end_user, url):
|
||||
"""Get information about a remote file.
|
||||
|
||||
Retrieves basic information about a file located at a remote URL,
|
||||
including content type and content length.
|
||||
|
||||
Args:
|
||||
app_model: The associated application model
|
||||
end_user: The end user making the request
|
||||
url: URL-encoded path to the remote file
|
||||
|
||||
Returns:
|
||||
dict: Remote file information including type and length
|
||||
|
||||
Raises:
|
||||
HTTPException: If the remote file cannot be accessed
|
||||
"""
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
resp = ssrf_proxy.head(decoded_url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
@ -32,9 +60,42 @@ class RemoteFileInfoApi(WebApiResource):
|
||||
}
|
||||
|
||||
|
||||
@web_ns.route("/remote-files/upload")
|
||||
class RemoteFileUploadApi(WebApiResource):
|
||||
@marshal_with(file_fields_with_signed_url)
|
||||
def post(self, app_model, end_user): # Add app_model and end_user parameters
|
||||
@web_ns.doc("upload_remote_file")
|
||||
@web_ns.doc(description="Upload a file from a remote URL")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
201: "Remote file uploaded successfully",
|
||||
400: "Bad request - invalid URL or parameters",
|
||||
413: "File too large",
|
||||
415: "Unsupported file type",
|
||||
500: "Failed to fetch remote file",
|
||||
}
|
||||
)
|
||||
@marshal_with(build_file_with_signed_url_model(web_ns))
|
||||
def post(self, app_model, end_user):
|
||||
"""Upload a file from a remote URL.
|
||||
|
||||
Downloads a file from the provided remote URL and uploads it
|
||||
to the platform storage for use in web applications.
|
||||
|
||||
Args:
|
||||
app_model: The associated application model
|
||||
end_user: The end user making the request
|
||||
|
||||
JSON Parameters:
|
||||
url: The remote URL to download the file from (required)
|
||||
|
||||
Returns:
|
||||
dict: File information including ID, signed URL, and metadata
|
||||
int: HTTP status code 201 for success
|
||||
|
||||
Raises:
|
||||
RemoteFileUploadError: Failed to fetch file from remote URL
|
||||
FileTooLargeError: File exceeds size limit
|
||||
UnsupportedFileTypeError: File type not supported
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("url", type=str, required=True, help="URL is required")
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -30,6 +30,10 @@ class SavedMessageListApi(WebApiResource):
|
||||
"data": fields.List(fields.Nested(message_fields)),
|
||||
}
|
||||
|
||||
post_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@marshal_with(saved_message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
if app_model.mode != "completion":
|
||||
@ -42,6 +46,7 @@ class SavedMessageListApi(WebApiResource):
|
||||
|
||||
return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
|
||||
|
||||
@marshal_with(post_response_fields)
|
||||
def post(self, app_model, end_user):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
@ -59,6 +64,11 @@ class SavedMessageListApi(WebApiResource):
|
||||
|
||||
|
||||
class SavedMessageApi(WebApiResource):
|
||||
delete_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@marshal_with(delete_response_fields)
|
||||
def delete(self, app_model, end_user, message_id):
|
||||
message_id = str(message_id)
|
||||
|
||||
|
||||
@ -53,6 +53,18 @@ class AppSiteApi(WebApiResource):
|
||||
"custom_config": fields.Raw(attribute="custom_config"),
|
||||
}
|
||||
|
||||
@api.doc("Get App Site Info")
|
||||
@api.doc(description="Retrieve app site information and configuration.")
|
||||
@api.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "App Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(app_fields)
|
||||
def get(self, app_model, end_user):
|
||||
"""Retrieve app site info."""
|
||||
|
||||
@ -30,6 +30,24 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowRunApi(WebApiResource):
|
||||
@api.doc("Run Workflow")
|
||||
@api.doc(description="Execute a workflow with provided inputs and files.")
|
||||
@api.doc(
|
||||
params={
|
||||
"inputs": {"description": "Input variables for the workflow", "type": "object", "required": True},
|
||||
"files": {"description": "Files to be processed by the workflow", "type": "array", "required": False},
|
||||
}
|
||||
)
|
||||
@api.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "App Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
"""
|
||||
Run workflow
|
||||
@ -62,11 +80,28 @@ class WorkflowRunApi(WebApiResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class WorkflowTaskStopApi(WebApiResource):
|
||||
@api.doc("Stop Workflow Task")
|
||||
@api.doc(description="Stop a running workflow task.")
|
||||
@api.doc(
|
||||
params={
|
||||
"task_id": {"description": "Task ID to stop", "type": "string", "required": True},
|
||||
}
|
||||
)
|
||||
@api.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "Task Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
"""
|
||||
Stop workflow task
|
||||
|
||||
@ -1,3 +1,16 @@
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
|
||||
|
||||
class MoreLikeThisConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class AppConfigModel(BaseModel):
|
||||
more_like_this: MoreLikeThisConfig = Field(default_factory=MoreLikeThisConfig)
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class MoreLikeThisConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> bool:
|
||||
@ -6,31 +19,14 @@ class MoreLikeThisConfigManager:
|
||||
|
||||
:param config: model config args
|
||||
"""
|
||||
more_like_this = False
|
||||
more_like_this_dict = config.get("more_like_this")
|
||||
if more_like_this_dict:
|
||||
if more_like_this_dict.get("enabled"):
|
||||
more_like_this = True
|
||||
|
||||
return more_like_this
|
||||
validated_config, _ = cls.validate_and_set_defaults(config)
|
||||
return AppConfigModel.model_validate(validated_config).more_like_this.enabled
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for more like this feature
|
||||
|
||||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("more_like_this"):
|
||||
config["more_like_this"] = {"enabled": False}
|
||||
|
||||
if not isinstance(config["more_like_this"], dict):
|
||||
raise ValueError("more_like_this must be of dict type")
|
||||
|
||||
if "enabled" not in config["more_like_this"] or not config["more_like_this"]["enabled"]:
|
||||
config["more_like_this"]["enabled"] = False
|
||||
|
||||
if not isinstance(config["more_like_this"]["enabled"], bool):
|
||||
raise ValueError("enabled in more_like_this must be of boolean type")
|
||||
|
||||
return config, ["more_like_this"]
|
||||
try:
|
||||
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
|
||||
except ValidationError as e:
|
||||
raise ValueError(
|
||||
"more_like_this must be of dict type and enabled in more_like_this must be of boolean type"
|
||||
)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from contextlib import contextmanager
|
||||
@ -143,6 +144,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
|
||||
self._workflow_response_converter = WorkflowResponseConverter(
|
||||
application_generate_entity=application_generate_entity,
|
||||
user=user,
|
||||
)
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
@ -373,7 +375,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle node succeeded events."""
|
||||
# Record files if it's an answer node or end node
|
||||
if event.node_type in [NodeType.ANSWER, NodeType.END]:
|
||||
if event.node_type in [NodeType.ANSWER, NodeType.END, NodeType.LLM]:
|
||||
self._recorded_files.extend(
|
||||
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
|
||||
)
|
||||
@ -896,7 +898,14 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
|
||||
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
|
||||
message = self._get_message(session=session)
|
||||
message.answer = self._task_state.answer
|
||||
|
||||
# If there are assistant files, remove markdown image links from answer
|
||||
answer_text = self._task_state.answer
|
||||
if self._recorded_files:
|
||||
# Remove markdown image links since we're storing files separately
|
||||
answer_text = re.sub(r"!\[.*?\]\(.*?\)", "", answer_text).strip()
|
||||
|
||||
message.answer = answer_text
|
||||
message.updated_at = naive_utc_now()
|
||||
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
|
||||
@ -8,6 +8,8 @@ from core.app.entities.task_entities import AppBlockingResponse, AppStreamRespon
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppGenerateResponseConverter(ABC):
|
||||
_blocking_response_type: type[AppBlockingResponse]
|
||||
@ -120,7 +122,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
if data:
|
||||
data.setdefault("message", getattr(e, "description", str(e)))
|
||||
else:
|
||||
logging.error(e)
|
||||
logger.error(e)
|
||||
data = {
|
||||
"code": "internal_server_error",
|
||||
"message": "Internal Server Error, please contact support.",
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, final
|
||||
|
||||
@ -14,6 +13,7 @@ from core.workflow.repositories.draft_variable_repository import (
|
||||
NoopDraftVariableSaver,
|
||||
)
|
||||
from factories import file_factory
|
||||
from libs.orjson import orjson_dumps
|
||||
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -174,7 +174,7 @@ class BaseAppGenerator:
|
||||
def gen():
|
||||
for message in generator:
|
||||
if isinstance(message, Mapping | dict):
|
||||
yield f"data: {json.dumps(message)}\n\n"
|
||||
yield f"data: {orjson_dumps(message)}\n\n"
|
||||
else:
|
||||
yield f"event: {message}\n\n"
|
||||
|
||||
|
||||
@ -3,7 +3,6 @@ from collections.abc import Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
@ -53,9 +52,7 @@ from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
EndUser,
|
||||
WorkflowRun,
|
||||
)
|
||||
|
||||
|
||||
@ -64,8 +61,10 @@ class WorkflowResponseConverter:
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||
user: Union[Account, EndUser],
|
||||
) -> None:
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._user = user
|
||||
|
||||
def workflow_start_to_stream_response(
|
||||
self,
|
||||
@ -92,27 +91,21 @@ class WorkflowResponseConverter:
|
||||
workflow_execution: WorkflowExecution,
|
||||
) -> WorkflowFinishStreamResponse:
|
||||
created_by = None
|
||||
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_))
|
||||
assert workflow_run is not None
|
||||
if workflow_run.created_by_role == CreatorUserRole.ACCOUNT:
|
||||
stmt = select(Account).where(Account.id == workflow_run.created_by)
|
||||
account = session.scalar(stmt)
|
||||
if account:
|
||||
created_by = {
|
||||
"id": account.id,
|
||||
"name": account.name,
|
||||
"email": account.email,
|
||||
}
|
||||
elif workflow_run.created_by_role == CreatorUserRole.END_USER:
|
||||
stmt = select(EndUser).where(EndUser.id == workflow_run.created_by)
|
||||
end_user = session.scalar(stmt)
|
||||
if end_user:
|
||||
created_by = {
|
||||
"id": end_user.id,
|
||||
"user": end_user.session_id,
|
||||
}
|
||||
|
||||
user = self._user
|
||||
if isinstance(user, Account):
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"name": user.name,
|
||||
"email": user.email,
|
||||
}
|
||||
elif isinstance(user, EndUser):
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"user": user.session_id,
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}")
|
||||
raise NotImplementedError(f"User type not supported: {type(user)}")
|
||||
|
||||
# Handle the case where finished_at is None by using current time as default
|
||||
finished_at_timestamp = (
|
||||
|
||||
@ -131,6 +131,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
|
||||
self._workflow_response_converter = WorkflowResponseConverter(
|
||||
application_generate_entity=application_generate_entity,
|
||||
user=user,
|
||||
)
|
||||
|
||||
self._application_generate_entity = application_generate_entity
|
||||
|
||||
@ -118,7 +118,7 @@ class QueueIterationNextEvent(AppQueueEvent):
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteratoin run in parallel mode run id"""
|
||||
"""iteration run in parallel mode run id"""
|
||||
node_run_index: int
|
||||
output: Optional[Any] = None # output for the current iteration
|
||||
duration: Optional[float] = None
|
||||
@ -201,7 +201,7 @@ class QueueLoopNextEvent(AppQueueEvent):
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteratoin run in parallel mode run id"""
|
||||
"""iteration run in parallel mode run id"""
|
||||
node_run_index: int
|
||||
output: Optional[Any] = None # output for the current loop
|
||||
duration: Optional[float] = None
|
||||
@ -382,7 +382,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteratoin run in parallel mode run id"""
|
||||
"""iteration run in parallel mode run id"""
|
||||
agent_strategy: Optional[AgentNodeStrategyInit] = None
|
||||
|
||||
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
class TaskPipilineError(ValueError):
|
||||
class TaskPipelineError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class RecordNotFoundError(TaskPipilineError):
|
||||
class RecordNotFoundError(TaskPipelineError):
|
||||
def __init__(self, record_name: str, record_id: str):
|
||||
super().__init__(f"{record_name} with id {record_id} not found")
|
||||
|
||||
|
||||
@ -32,6 +32,8 @@ from extensions.ext_database import db
|
||||
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageCycleManager:
|
||||
def __init__(
|
||||
@ -98,7 +100,7 @@ class MessageCycleManager:
|
||||
conversation.name = name
|
||||
except Exception as e:
|
||||
if dify_config.DEBUG:
|
||||
logging.exception("generate conversation name failed, conversation_id: %s", conversation_id)
|
||||
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
|
||||
pass
|
||||
|
||||
db.session.merge(conversation)
|
||||
|
||||
@ -19,6 +19,7 @@ class ModelStatus(Enum):
|
||||
QUOTA_EXCEEDED = "quota-exceeded"
|
||||
NO_PERMISSION = "no-permission"
|
||||
DISABLED = "disabled"
|
||||
CREDENTIAL_REMOVED = "credential-removed"
|
||||
|
||||
|
||||
class SimpleModelProviderEntity(BaseModel):
|
||||
@ -54,6 +55,7 @@ class ProviderModelWithStatusEntity(ProviderModel):
|
||||
|
||||
status: ModelStatus
|
||||
load_balancing_enabled: bool = False
|
||||
has_invalid_load_balancing_configs: bool = False
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
"""
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user